# torch/utils/_triton.py — helpers for detecting Triton availability and hashing its backend.
# mypy: allow-untyped-defs
import functools
import hashlib
  4. @functools.lru_cache(None)
  5. def has_triton_package() -> bool:
  6. try:
  7. import triton
  8. return triton is not None
  9. except ImportError:
  10. return False
  11. @functools.lru_cache(None)
  12. def has_triton() -> bool:
  13. from torch._dynamo.device_interface import get_interface_for_device
  14. def cuda_extra_check(device_interface):
  15. return device_interface.Worker.get_device_properties().major >= 7
  16. def _return_true(device_interface):
  17. return True
  18. triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}
  19. def is_device_compatible_with_triton():
  20. for device, extra_check in triton_supported_devices.items():
  21. device_interface = get_interface_for_device(device)
  22. if device_interface.is_available() and extra_check(device_interface):
  23. return True
  24. return False
  25. return is_device_compatible_with_triton() and has_triton_package()
  26. @functools.lru_cache(None)
  27. def triton_backend():
  28. import torch
  29. if torch.version.hip:
  30. # Does not work with ROCm
  31. return None
  32. from triton.compiler.compiler import make_backend
  33. from triton.runtime.driver import driver
  34. target = driver.active.get_current_target()
  35. return make_backend(target)
  36. @functools.lru_cache(None)
  37. def triton_hash_with_backend():
  38. import torch
  39. if torch.version.hip:
  40. # Does not work with ROCm
  41. return None
  42. from triton.compiler.compiler import triton_key
  43. backend = triton_backend()
  44. key = f"{triton_key()}-{backend.hash()}"
  45. # Hash is upper case so that it can't contain any Python keywords.
  46. return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
  47. def dtype_to_string(dtype):
  48. if dtype.name.startswith("fp"):
  49. suffix = "float" + dtype.name[2:]
  50. elif dtype.name.startswith("bf"):
  51. suffix = "bfloat" + dtype.name[2:]
  52. else:
  53. suffix = dtype.name
  54. return "triton.language." + suffix
  55. def patch_triton_dtype_repr():
  56. import triton
  57. # Hack to get triton dtype repr to produce an evaluatable expression
  58. # triton.language.float32 emits triton.language.fp32 which does not
  59. # exist
  60. # REMOVE when https://github.com/openai/triton/pull/3342 lands
  61. triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)