hints.py

# mypy: allow-untyped-defs
import collections
import typing
from dataclasses import fields
from enum import auto, Enum
from typing import Dict, List, Optional, Union

# NOTE: if these sizes fail asserts, submit a PR to increase them
TRITON_MAX_BLOCK = {
    "X": 2048,
    "Y": 1024,
    "Z": 1024,
    "R": 4096 * 16,  # * 16 is multi-kernel only
}
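
# Illustrative only (hypothetical names): the NOTE above refers to asserts of
# roughly this shape in the Triton codegen, where a chosen block size is
# checked against this table:
#
#   xblock = 4096  # hypothetical autotuner choice
#   assert xblock <= TRITON_MAX_BLOCK["X"], "increase TRITON_MAX_BLOCK['X']"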


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


# Attempt to import AttrsDescriptor from Triton
try:
    from triton.compiler.compiler import AttrsDescriptor

    attrs_descriptor_available = True
    # Determine if 'ids_of_folded_args' and 'divisible_by_8' are valid
    # fields for this version's AttrsDescriptor
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
    attrs_descriptor_available = False

# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:

    def instance_descriptor(
        divisible_by_16=None,
        equal_to_1=None,
        ids_of_folded_args=None,
        divisible_by_8=None,
    ):
        # Prepare the arguments for AttrsDescriptor
        kwargs = {
            "divisible_by_16": divisible_by_16,
            "equal_to_1": equal_to_1,
        }

        # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
        if ids_of_folded_args_available:
            kwargs["ids_of_folded_args"] = ids_of_folded_args

        # Likewise for 'divisible_by_8'
        if divisible_by_8_available:
            kwargs["divisible_by_8"] = divisible_by_8

        # Instantiate AttrsDescriptor with the prepared arguments
        return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[tuple(), tuple(), tuple(), tuple()],
    )
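
# Illustrative usage (hypothetical index values): both branches above yield an
# `instance_descriptor` callable taking tuples of argument indices, so callers
# need not care which Triton version is installed:
#
#   desc = instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(5,))
#
# On Tritons without AttrsDescriptor, the namedtuple fallback just records the
# fields; otherwise the optional kwargs are forwarded to AttrsDescriptor only
# when it actually declares them.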


_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
    ELEMENTS_PER_WARP_32 = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid Python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__
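
# Illustrative: with `__repr__` aliased to `Enum.__str__`, a set of hints
# renders as valid source text, e.g.
#
#   repr({AutotuneHint.ELEMENTS_PER_WARP_32})
#   # -> "{AutotuneHint.ELEMENTS_PER_WARP_32}"
#
# which can be emitted directly into generated code, unlike the default
# "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>" form.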


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    multi_processor_count: Optional[int] = None

    @classmethod
    def create(cls, device):
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type if torch.version.hip is None else "hip"
        device_interface = get_interface_for_device(device)
        if device_type == "cuda":
            props = device_interface.get_device_properties(device)
            return cls(
                type=device_type,
                index=device.index,
                cc=device_interface.get_compute_capability(device),
                major=props.major,
                regs_per_multiprocessor=props.regs_per_multiprocessor,
                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                multi_processor_count=props.multi_processor_count,
            )
        return cls(
            type=device_type,
            index=device.index,
            cc=device_interface.get_compute_capability(device),
        )
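
# Illustrative usage (assumes a CUDA-enabled torch build):
#
#   import torch
#   props = DeviceProperties.create(torch.device("cuda", 0))
#   props.multi_processor_count  # populated on the "cuda" path above
#
# For non-CUDA devices only `type`, `index`, and `cc` are filled in; the
# remaining fields keep their None defaults.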


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    numel: Optional[str] = None

    def bindings_type(self):
        if self.ctype == "half*":
            return "void*"  # half not defined
        return self.ctype

    def halide_type(self):
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"
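
# Illustrative mapping (values follow directly from the methods above):
#
#   HalideInputSpec("float*", "in_ptr0").halide_type()
#   # -> "halide_type_of<float>()"
#   HalideInputSpec("half*", "in_ptr1").halide_type()
#   # -> "halide_type_t(halide_type_float, 16)"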


class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: str
    scheduler_flags: Dict[str, Union[int, str]]

    def args(self):
        """Command line args to pass to halide generator"""
        args = [f"target={self.target}", f"autoscheduler={self.scheduler}"]
        for k, v in self.scheduler_flags.items():
            args.append(f"autoscheduler.{k}={v}")
        return args
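
# Illustrative (hypothetical target/scheduler values):
#
#   HalideMeta(
#       argtypes=[HalideInputSpec("float*", "in_ptr0")],
#       target="host",
#       scheduler="Anderson2021",
#       scheduler_flags={"parallelism": 8},
#   ).args()
#   # -> ["target=host", "autoscheduler=Anderson2021",
#   #     "autoscheduler.parallelism=8"]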