functions.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031
  1. # mypy: ignore-errors
  2. import collections
  3. import copy
  4. import functools
  5. import inspect
  6. import itertools
  7. import types
  8. from typing import Dict, List, Optional, TYPE_CHECKING, Union
  9. import torch
  10. from .. import variables
  11. from ..bytecode_transformation import create_call_function, create_rot_n
  12. from ..exc import unimplemented, Unsupported
  13. from ..guards import GuardBuilder, install_guard
  14. from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
  15. from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
  16. from .base import MutableLocal, typestr, VariableTracker
  17. from .constant import ConstantVariable
  18. if TYPE_CHECKING:
  19. from torch._guards import Source
  20. def wrap_bound_arg(tx, val, source=None):
  21. # Source propagation is best effort since not every object we encounter has a source to begin with.
  22. if isinstance(val, VariableTracker):
  23. return val
  24. elif not source:
  25. from torch._dynamo.variables.builder import SourcelessBuilder
  26. return SourcelessBuilder.create(tx, val)
  27. else:
  28. # Create a lazy variable to avoid guarding on __defaults__ unless really
  29. # needed.
  30. return variables.LazyVariableTracker.create(val, source)
  31. def wrap_args_kwargs(tx, result):
  32. for k, v in list(result.items()):
  33. if isinstance(v, (tuple, dict)):
  34. # args/kwargs
  35. result[k] = wrap_bound_arg(tx, v)
  36. def init_cellvars(parent, result, code):
  37. closure_cells = dict()
  38. side_effects = parent.output.side_effects
  39. # for name in itertools.chain(code.co_cellvars, code.co_freevars):
  40. for name in code.co_cellvars:
  41. closure_cells[name] = side_effects.track_cell_new()
  42. if name in result:
  43. side_effects.store_cell(closure_cells[name], result.pop(name))
  44. return closure_cells
  45. def _create_nested_fn(
  46. code, f_globals, name, defaults, closure, kwdefaults, annotations
  47. ):
  48. from types import FunctionType
  49. func = FunctionType(code, f_globals, name, defaults, closure)
  50. func.__kwdefaults__ = kwdefaults
  51. if isinstance(annotations, tuple):
  52. from itertools import pairwise
  53. annotations = dict(pairwise(annotations))
  54. # TypeError: __annotations__ must be set to a dict object
  55. assert annotations is None or isinstance(annotations, dict)
  56. func.__annotations__ = annotations
  57. return func
  58. class BaseUserFunctionVariable(VariableTracker):
  59. def get_filename(self):
  60. return self.get_code().co_filename
  61. def get_name(self):
  62. return self.get_code().co_name
  63. def call_function(
  64. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  65. ) -> "VariableTracker":
  66. return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  67. def call_hasattr(self, tx, name: str) -> VariableTracker:
  68. result = False
  69. try:
  70. result = hasattr(self.get_function(), name)
  71. except NotImplementedError:
  72. if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
  73. result = True
  74. return variables.ConstantVariable.create(result)
  75. def inspect_parameter_names(self):
  76. return list(inspect.signature(self.get_function()).parameters)
  77. def closure_vars(self, tx):
  78. return {}
class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function"""

    # Fields holding plain python values rather than nested VariableTrackers.
    _nonvar_fields = {
        "fn",
        "is_constant",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    @classmethod
    def create_with_source(cls, value, source):
        # Guard on the function's closure so changes trigger recompilation.
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, fn, is_constant=False, **kwargs):
        # NOTE(review): the ``is_constant`` parameter is accepted but never
        # read; self.is_constant is derived solely from the
        # ``_dynamo_marked_constant`` attribute below -- confirm intended.
        super().__init__(**kwargs)
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False
        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def as_python_constant(self):
        # Only exact UserFunctionVariable instances are constants;
        # subclasses (such as methods) usually aren't a constant.
        if istype(self, UserFunctionVariable):
            return self.fn
        return super().as_python_constant()

    def self_args(self):
        # Plain functions bind no implicit arguments.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call-site args/kwargs against the function's signature.

        Returns ``(result, closure_cells)``: ``result`` maps parameter and
        freevar names to VariableTrackers; ``closure_cells`` maps cellvar
        names to tracked cells.
        """
        assert not self.is_constant
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx)
        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # Shadow function whose defaults are VariableTrackers, so that
        # inspect.signature().bind() below fills defaulted params with VTs.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }
        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        # Wrap any raw *args / **kwargs bundles, then allocate cellvar cells.
        wrap_args_kwargs(tx, result)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # Implicit cell used by zero-argument super(): expose it as
                # a user-defined class variable.
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # Cell already tracked from a previous binding.
                        out = side_effects[cell]
                    else:
                        # Source chain: fn.__closure__[idx].cell_contents
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        try:
                            contents_var = VariableBuilder(
                                parent, closure_cell_contents
                            )(cell.cell_contents)
                        except ValueError:
                            # Cell has not yet been assigned
                            contents_var = variables.DeletedVariable()
                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue
                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )
                    result[name] = out
                else:
                    # No source available: wrap the contents sourcelessly.
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
        return result, closure_cells

    def export_freevars(self, parent, child):
        # Plain functions: nothing to export back into the parent tracer.
        pass

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        # The real function object is always available here.
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            # Marked constant: run eagerly and install the result on the graph.
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )
        return super().call_function(tx, args, kwargs)
  238. class UserMethodVariable(UserFunctionVariable):
  239. """Some unsupported user-defined method"""
  240. def __init__(self, fn, obj, **kwargs):
  241. super().__init__(fn=fn, **kwargs)
  242. self.obj = obj
  243. def __str__(self):
  244. return f"{self.__class__.__name__}({self.fn}, {self.obj})"
  245. def self_args(self):
  246. return [self.obj]
  247. def python_type(self):
  248. return types.MethodType
  249. def call_function(
  250. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  251. ) -> "VariableTracker":
  252. # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
  253. # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
  254. # since we ensure `forward` of allowed modules can be traced by AOT safely.
  255. # Note this is not only for allowed modules, as user customized modules can extend from
  256. # allowed modules but using parent's `forward` method, which is also covered by this branch.
  257. # If we are tracing the higher order op, we want Dynamo to step inside
  258. # the module call so that Dynamo can see the underlying parameters and
  259. # buffers and raise them as inputs to the graph. The is_root_tracer
  260. # check bypasses the if condition for non-root tracers and directly
  261. # calls the super().call_function at the end, which is basically
  262. # equivalent of inlining the method.
  263. if tx.output.is_root_tracer() and isinstance(
  264. self.obj, variables.NNModuleVariable
  265. ):
  266. module_attr = getattr(self.fn, "__module__", "")
  267. # inline torch.nn.utils.parametrize
  268. if (
  269. module_attr is not None
  270. and module_attr.startswith("torch.nn.")
  271. and module_attr != "torch.nn.utils.parametrize"
  272. or self.is_constant
  273. ):
  274. return self.obj.call_method(
  275. tx, self.fn.__name__, args, kwargs, constant=self.is_constant
  276. )
  277. if self.is_constant:
  278. fn = getattr(self.obj.value, self.fn.__name__)
  279. return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
  280. return super().call_function(tx, args, kwargs)
  281. def inspect_parameter_names(self):
  282. return super().inspect_parameter_names()[1:]
  283. class WrappedUserMethodVariable(UserMethodVariable):
  284. def __init__(self, wrapped, context, **kwargs):
  285. kwargs.pop("fn", None)
  286. kwargs.pop("obj", None)
  287. super().__init__(wrapped.fn, wrapped.obj, **kwargs)
  288. self.wrapped = wrapped
  289. self.context = context
  290. def call_function(
  291. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  292. ) -> "VariableTracker":
  293. self.context.enter(tx)
  294. result = super().call_function(tx, args, kwargs)
  295. self.context.exit(tx)
  296. return result
  297. class WrappedUserFunctionVariable(UserFunctionVariable):
  298. def __init__(self, wrapped, context, **kwargs):
  299. kwargs.pop("fn", None)
  300. kwargs.pop("obj", None)
  301. super().__init__(wrapped.fn, **kwargs)
  302. self.wrapped = wrapped
  303. self.context = context
  304. def call_function(
  305. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  306. ) -> "VariableTracker":
  307. self.context.enter(tx)
  308. result = super().call_function(tx, args, kwargs)
  309. self.context.exit(tx)
  310. return result
  311. def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
  312. def convert(x):
  313. if isinstance(x, variables.TensorVariable):
  314. return x.get_real_value()
  315. return x.as_python_constant()
  316. args = [convert(x) for x in args]
  317. kwargs = {k: convert(v) for k, v in kwargs.items()}
  318. res = fn(*args, **kwargs)
  319. return tx.output.register_attr_or_module(
  320. res,
  321. name,
  322. source=ConstantSource(name),
  323. )
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """A function created by MAKE_FUNCTION while tracing (defined inside the
    compiled region).  There is no real python function object yet; its
    pieces (name, code, defaults, closure, ...) are tracked individually as
    VariableTrackers."""

    # Fields holding plain python values rather than nested VariableTrackers.
    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wrapped_reconstructible=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        if closure is None:
            # Without a closure there is nothing for the scope to resolve.
            closure_scope = None
        self.closure_scope = closure_scope
        # Either a source or a VT with .can_reconstruct() == True
        self.wrapped_reconstructible: Optional[
            Union[Source, VariableTracker]
        ] = wrapped_reconstructible

    def self_args(self):
        # Nested functions bind no implicit arguments.
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real python function.

        Only possible when there is no closure; callers must handle the
        NotImplementedError otherwise.
        """
        if self.closure:
            raise NotImplementedError
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # NOTE(review): pairwise() yields OVERLAPPING pairs, so a
                # flat (k1, v1, k2, v2) annotations tuple also gains bogus
                # (v1, k2) entries in the resulting dict -- confirm whether
                # non-overlapping pairing (zip of even/odd slices) was
                # intended here.
                from itertools import pairwise

                annotations = dict(pairwise(annotations))
            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call args via a shadow function, then resolve freevars.

        Returns ``(result, closure_cells)`` like
        UserFunctionVariable.bind_args.
        """
        from .misc import InlinedClosureVariable

        code = self.get_code()
        # Shadow function used purely for signature binding; the closure
        # slots are placeholder cells.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result)
        closure_cells = init_cellvars(parent, result, code)
        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            # NOTE(review): the tracked cell is expected to act as a name
            # marker here; the getattr default makes this hold when the
            # attribute is absent -- confirm the invariant.
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = self.closure.items[idx]
        return result, closure_cells

    def export_freevars(self, parent, child):
        # Propagate freevar bindings discovered in the child tracer back up
        # to the parent's symbolic locals.
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function via _create_nested_fn.

        Pushes _create_nested_fn plus its 7 arguments (code, globals, name,
        defaults, closure, kwdefaults, annotations), calls it, and optionally
        applies functools.wraps(<wrapped>) to the result.
        """
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(ConstantVariable.create(self.code.value.co_name))
        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])
        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])
        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])
        if self.annotations:
            # Prefer a constant dict when possible; otherwise reconstruct
            # the annotations VT itself.
            try:
                annotations = self.annotations.as_python_constant()
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])
        codegen.extend_output(create_call_function(7, push_null=True))
        if self.wrapped_reconstructible:
            # functools.wraps(<wrapped>)(<new function>)
            codegen.load_import_from("functools", "wraps")
            codegen(self.wrapped_reconstructible)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))
class SkipFunctionVariable(VariableTracker):
    """A callable Dynamo refuses to trace into (skipfile / unsupported).

    Calls either fold through for a small allowlist of constant-foldable
    functions, special-case functools.wraps, or graph-break with a
    descriptive message.
    """

    # Fields holding plain python values rather than nested VariableTrackers.
    _nonvar_fields = {
        "value",
        "reason",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value, reason=None, **kwargs):
        super().__init__(**kwargs)
        # The skipped callable and (optionally) why it was skipped.
        self.value = value
        self.reason = reason

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        # Identity-guard the skipped callable.
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        # Functions whose constant-arg calls can be executed eagerly; maps
        # each to the VariableTracker class that wraps the eager result.
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Fold through the functions(e.g, collections.namedtuple)
        # that inputs & outputs are all python constants
        elif (
            self.value in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal()
            )
        elif (
            self.value is functools.wraps
            and not kwargs
            and len(args) == 1
            and (
                args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
            )
        ):
            # functools.wraps(fn): remember how to reconstruct the wrapped
            # function so the wrapper can be rebuilt at reconstruction time.

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    if args[0].source:
                        reconstructible = args[0].source
                    else:
                        reconstructible = args[0]
                    return fn.clone(wrapped_reconstructible=reconstructible)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(wraps)
        else:
            # No special case applies: graph-break with a helpful message.
            try:
                path = inspect.getfile(self.value)
                msg = f"'skip function {self.value.__qualname__} in file {path}'"
            except TypeError:
                # inspect.getfile raises TypeError for builtins / C extensions.
                known_python_builtin_modules = {"_abc", "_warnings"}
                if self.value.__module__ in known_python_builtin_modules:
                    msg = (
                        f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
                        f"Please file an issue on GitHub "
                        f"so the PyTorch team can add support for it. "
                    )
                else:
                    msg = (
                        f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
                        f"This function is either a Python builtin (e.g. _warnings.warn) "
                        f"or a third-party C/C++ Python extension (perhaps created with pybind). "
                        f"If it is a Python builtin, please file an issue on GitHub "
                        f"so the PyTorch team can add support for it and see the next case for a workaround. "
                        f"If it is a third-party C/C++ Python extension, please "
                        f"either wrap it into a PyTorch-understood custom operator "
                        f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                        f"for more details) or, if it is traceable, use "
                        f"torch.compiler.allow_in_graph."
                    )
                # also warn on it because most users won't see the graph break message
                torch._dynamo.utils.warn_once(msg)
            msg += f"', {self.reason}'" if self.reason else ""
            unimplemented(msg)
  559. def _traceable_collective_remaps():
  560. # We can't rely on importing from distributed, since it's not always built
  561. if torch.distributed.is_available():
  562. from torch.distributed._functional_collectives import (
  563. traceable_collective_remaps,
  564. )
  565. return traceable_collective_remaps
  566. return {}
  567. def _traceable_collectives_source(tx, fn):
  568. assert torch.distributed.is_available(), "Illegal invocation."
  569. assert fn in _traceable_collective_remaps().values()
  570. inner_name = fn.__name__
  571. path_source = tx.import_source("torch.distributed._functional_collectives")
  572. return AttrSource(path_source, inner_name)
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.

    This class provides both a way to check if a function is remappable, and perform the remapping.

    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
    we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
    than status-quo as we currently graph-break on all distributed.* collectives.
    """

    def __init__(self, fn, *, replacement_var, **kwargs):
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        # The traceable collective that calls are redirected to.
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx, old_fn, source, **options):
        # Factory: look up the traceable replacement and wrap both functions.
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        # Only plain functions present in the remap table can be rewritten.
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(tx, fn):
        # Returns (replacement function, source pointing at it).
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(tx, new_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        import torch.distributed as dist
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        # Merge args into kwargs so positional and keyword args
        # can be processed the same way.
        signature = inspect.signature(self.fn)
        kwargs = dict(signature.bind(*args, **kwargs).arguments)
        args = ()
        # Traceable collectives are synchronous; async_op=True can't be mapped.
        if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
            )
        if self.fn in (
            dist.all_reduce,
            dist.reduce_scatter_tensor,
            dist._reduce_scatter_base,
        ):
            # Functional collectives take the reduce op as a string; convert
            # the c10d ReduceOp argument (or the signature default).
            reduce_op_var = kwargs.get("op")
            reduce_op = (
                reduce_op_var.value
                if reduce_op_var is not None
                else signature.parameters["op"].default
            )
            if reduce_op not in REDUCE_OP_TO_STR:
                raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
            kwargs["op"] = variables.ConstantVariable.create(
                REDUCE_OP_TO_STR[reduce_op]
            )
        return self.replacement_var.call_function(tx, args, kwargs)
  637. class FunctoolsPartialVariable(VariableTracker):
  638. def __init__(self, func: VariableTracker, args, keywords, **kwargs):
  639. super().__init__(**kwargs)
  640. self.func = func
  641. assert isinstance(args, list)
  642. self.args = args
  643. assert isinstance(keywords, dict)
  644. self.keywords = keywords
  645. def reconstruct(self, codegen):
  646. codegen.load_import_from("functools", "partial")
  647. codegen(self.func)
  648. if self.args:
  649. codegen.foreach(self.args)
  650. if not self.keywords:
  651. codegen.extend_output(create_call_function(len(self.args) + 1, True))
  652. return
  653. codegen.foreach(self.keywords.values())
  654. keys = tuple(self.keywords.keys())
  655. codegen.extend_output(
  656. codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, True)
  657. )
  658. def get_function(self):
  659. return self.as_python_constant()
  660. def call_function(
  661. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  662. ) -> "VariableTracker":
  663. merged_args = self.args + args
  664. merged_kwargs = {**self.keywords, **kwargs}
  665. return self.func.call_function(tx, merged_args, merged_kwargs)
  666. def call_hasattr(self, tx, name: str) -> VariableTracker:
  667. # functools.partial uses slots, so attributes are constant
  668. return variables.ConstantVariable.create(
  669. hasattr(functools.partial(identity), name)
  670. )
  671. def as_python_constant(self):
  672. return functools.partial(
  673. self.func.as_python_constant(),
  674. *[arg.as_python_constant() for arg in self.args],
  675. **{k: v.as_python_constant() for k, v in self.keywords.items()},
  676. )
  677. def guard_as_python_constant(self):
  678. """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
  679. return functools.partial(
  680. self.func.guard_as_python_constant(),
  681. *[v.guard_as_python_constant() for v in self.args],
  682. **{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
  683. )
  684. class TritonKernelVariable(VariableTracker):
    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        """Track a user triton kernel.

        Args:
            kernel: the triton JITFunction or Autotuner object.
            kernel_idx: expected index in the kernel side table, or None.
            grid: the launch grid (VT or None if supplied later via __getitem__).
        """
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)
        assert kernel is not None
        self.kernel = kernel
        # Register (or re-find) the kernel in the global side table.
        self.kernel_idx = kernel_side_table.add_kernel(kernel)
        assert kernel_idx is None or self.kernel_idx == kernel_idx
        self.grid = grid
        if isinstance(kernel, Autotuner):
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner.__init__).parameters
            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.
            if (
                (
                    "warmup" in defaults
                    and defaults["warmup"].default
                    != get_first_attr(kernel, "num_warmups", "warmup")
                )
                or (
                    "rep" in defaults
                    and defaults["rep"].default
                    != get_first_attr(kernel, "num_reps", "rep")
                )
                or (
                    "prune_configs_by" in defaults
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
                # Set via reset_to_zero argument
                or len(kernel.reset_idx) != 0
                or len(kernel.restore_idx) != 0
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )
  723. def call_function(
  724. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  725. ) -> "VariableTracker":
  726. from triton.runtime.autotuner import autotune, Autotuner, Config
  727. from .constant import ConstantVariable
  728. from .dicts import ConstDictVariable
  729. from .lists import BaseListVariable
  730. if "num_ctas" in kwargs:
  731. raise Unsupported(
  732. "Passing num_ctas directly to the Triton kernel is not supported. "
  733. "Please use a Config in @triton.autotune instead."
  734. )
  735. special_kwargs = {}
  736. for name in ("num_warps", "num_stages"):
  737. if name in kwargs:
  738. # remove special kwargs from `kwargs`
  739. val = kwargs.pop(name)
  740. assert isinstance(val, ConstantVariable)
  741. special_kwargs[name] = val.value
  742. if special_kwargs:
  743. if isinstance(self.kernel, Autotuner):
  744. # if there is Autotuner already, set
  745. # special kwargs to each of its configs
  746. new_configs = copy.deepcopy(self.kernel.configs)
  747. for config in new_configs:
  748. config.__dict__.update(special_kwargs)
  749. new_kernel = autotune(configs=new_configs, key=[])(self.kernel.fn)
  750. else:
  751. # if there is no Autotuner, wrap the kernel into a
  752. # new one with a single config with special kwargs
  753. new_config = Config(kwargs={}, **special_kwargs)
  754. new_kernel = autotune(configs=[new_config], key=[])(self.kernel)
  755. # create a new variable to contain the new (wrapped) kernel;
  756. # skip kernel_idx to get a new record in the kernel side table
  757. new_var = TritonKernelVariable(new_kernel, None, self.grid)
  758. return new_var.call_function(tx, args, kwargs)
  759. if self.grid is None:
  760. raise Unsupported("Triton kernels should always be called with a grid")
  761. # Both for grid's meta as well as for the kernel, we need combined
  762. # args and kwargs combined and normalized
  763. combined_args_raw = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
  764. combined_args = {
  765. variables.ConstantVariable.create(k): v
  766. for k, v in combined_args_raw.items()
  767. }
  768. configs = (
  769. [config.kwargs for config in self.kernel.configs]
  770. if isinstance(self.kernel, Autotuner)
  771. else [{}]
  772. )
  773. grids = []
  774. for config_args in configs:
  775. # If the grid is a function, then lets execute it and convert it to
  776. # a list
  777. grid = self.grid
  778. if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
  779. # Populate the special "meta" argument to call the grid function
  780. config_args = {
  781. ConstantVariable.create(k): ConstantVariable.create(v)
  782. for k, v in config_args.items()
  783. }
  784. meta = ConstDictVariable({**combined_args, **config_args}, dict)
  785. grid = grid.call_function(tx, [meta], {})
  786. # Now, the grid must be a list either originally or through above
  787. # modification
  788. if isinstance(grid, BaseListVariable):
  789. grids.append(grid.as_proxy())
  790. else:
  791. unimplemented(f"grid for the triton kernel is {type(grid)}")
  792. for i in range(len(grids)):
  793. if not isinstance(grids[i], tuple):
  794. raise Unsupported("Only tuple grids are supported")
  795. # inductor expects all grids to be 3-tuple so lets make it
  796. if len(grids[i]) == 1:
  797. grids[i] = (grids[i][0], 1, 1)
  798. elif len(grids[i]) == 2:
  799. grids[i] = (grids[i][0], grids[i][1], 1)
  800. elif len(grids[i]) > 3:
  801. raise Unsupported("Grid can have at most rank 3")
  802. assert len(grids) != 0
  803. if len(set(grids)) == 1:
  804. # If there's only one unique grid, lets simplify
  805. grids = [grids[0]]
  806. from torch._higher_order_ops.triton_kernel_wrap import (
  807. kernel_side_table,
  808. triton_kernel_wrapper_mutation,
  809. )
  810. # Combine args and kwargs and pass as a dict so that if user defined triton
  811. # kernel uses variables as 'grid' or 'kernel', it does not conflict with
  812. # parameters of the wrapper function
  813. constant_args = {
  814. k: v.as_python_constant()
  815. for k, v in combined_args_raw.items()
  816. if isinstance(v, ConstantVariable)
  817. }
  818. non_constant_args = {
  819. k: v
  820. for k, v in combined_args.items()
  821. if not isinstance(v, ConstantVariable)
  822. }
  823. constant_args_idx = kernel_side_table.add_constant_args(constant_args)
  824. meta = ConstDictVariable(non_constant_args, dict)
  825. tx.output.create_proxy(
  826. "call_function",
  827. triton_kernel_wrapper_mutation,
  828. (),
  829. {
  830. "kernel_idx": self.kernel_idx,
  831. "constant_args_idx": constant_args_idx,
  832. "grid": grids,
  833. "kwargs": meta.as_proxy(),
  834. },
  835. )
  836. return variables.ConstantVariable(
  837. None,
  838. )
  839. def call_method(
  840. self,
  841. tx,
  842. name,
  843. args: "List[VariableTracker]",
  844. kwargs: "Dict[str, VariableTracker]",
  845. ) -> "VariableTracker":
  846. if name == "__getitem__":
  847. # __getitem__ should only be called if we don't already have a grid
  848. # Only grid needs to be passed
  849. if self.grid is not None or len(args) != 1:
  850. raise Unsupported(
  851. "Triton kernels should be called with only a single grid"
  852. )
  853. return TritonKernelVariable(
  854. kernel=self.kernel,
  855. kernel_idx=self.kernel_idx,
  856. grid=args[0],
  857. )
  858. elif name == "run":
  859. if "grid" not in kwargs:
  860. raise Unsupported("Triton kernel requires to be called with a grid")
  861. grid = kwargs.pop("grid")
  862. kwargs.pop("warmup", None)
  863. # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
  864. return TritonKernelVariable(
  865. kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
  866. ).call_function(tx, args, kwargs)
  867. # Bail out to parent's implementation
  868. return super().call_method(tx, name, args, kwargs)