# mypy: allow-untyped-defs
import functools
import inspect
import logging
import math
import re
from typing import Dict, List

import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators

from torch._logging import warning_once
from torch._streambase import _StreamBase

from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import SyntheticLocalSource
from ..utils import (
    check_unspec_or_constant_args,
    guard_if_dyn,
    has_torch_function,
    hashable,
    product,
    proxy_args_kwargs,
    unwrap_if_wrapper,
)
from .base import VariableTracker
from .ctx_manager import (
    AutocastModeVariable,
    NullContextVariable,
    TorchFunctionDisableVariable,
)
from .distributed import DistributedVariable, ProcessGroupVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function

try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

log = logging.getLogger(__name__)

supported_ctx_manager_classes = dict.fromkeys(
    [
        torch.profiler.profiler.profile,
        torch.autograd.forward_ad._set_fwd_grad_enabled,
        torch.autograd.forward_ad.dual_level,
        torch.autograd.profiler.profile,
        torch.autograd.profiler.record_function,
        torch._C.DisableTorchFunctionSubclass,
        torch._functorch.vmap.vmap_increment_nesting,
        torch._functorch.eager_transforms.grad_increment_nesting,
        torch._functorch.eager_transforms.jvp_increment_nesting,
        torch._functorch.eager_transforms.enable_inplace_requires_grad,
        torch.amp.autocast_mode.autocast,
        torch.autograd.grad_mode.enable_grad,
        torch.autograd.grad_mode.inference_mode,
        torch.autograd.grad_mode.no_grad,
        torch.autograd.grad_mode.set_grad_enabled,
        torch.autograd.graph.disable_saved_tensors_hooks,
        torch.cpu.amp.autocast_mode.autocast,
        torch.cuda.amp.autocast_mode.autocast,
    ]
)
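# Illustrative note: each entry above maps onto a dynamo-specific context
# manager VariableTracker (see TorchCtxManagerClassVariable.call_function
# below), so e.g. torch.no_grad is handled via GradModeVariable rather than
# traced as an opaque call.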

REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
    [
        torch.onnx.operators.shape_as_tensor,
        torch._shape_as_tensor,
    ]
)

constant_fold_functions = [
    torch._assert,
    torch._utils._get_device_index,
    torch._C._get_cublas_allow_tf32,
    torch._C._is_any_autocast_enabled,
    torch.cuda.get_device_properties,
    torch.cuda.is_available,
    torch.distributed.is_available,
    torch.get_autocast_dtype,
    torch.get_autocast_gpu_dtype,
    torch.get_default_dtype,
    torch.is_autocast_cache_enabled,
    torch.is_autocast_cpu_enabled,
    torch.is_autocast_enabled,
    torch.is_complex,
    torch.is_floating_point,
    torch.nn.functional._Reduction.get_enum,  # type: ignore[attr-defined]
    torch.promote_types,
    torch._C._get_privateuse1_backend_name,
]

if torch.distributed.is_available():
    constant_fold_functions.extend(
        [
            torch.distributed.is_initialized,
            torch.distributed.get_rank,
            torch.distributed.get_world_size,
        ]
    )

# Convert to dict for O(1) access times
constant_fold_functions = dict.fromkeys(constant_fold_functions)
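# Illustrative note: when all arguments are constants (or unspecialized
# scalars), the functions above are simply evaluated at trace time and the
# result is baked in as a constant -- e.g. torch.get_default_dtype() or
# torch.promote_types(torch.int32, torch.float32) never become FX graph nodes.
# See the constant-fold branch in TorchInGraphFunctionVariable.call_function.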

tracing_state_functions = {
    torch.jit.is_scripting: False,
    torch.jit.is_tracing: False,
    torch._C._get_tracing_state: None,
    torch.fx._symbolic_trace.is_fx_tracing: False,
    torch.onnx.is_in_onnx_export: False,
    torch._dynamo.external_utils.is_compiling: True,
    torch._utils.is_compiling: True,
    torch.compiler.is_compiling: True,
    torch.compiler.is_dynamo_compiling: True,
}
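# Illustrative note: inside compiled code these predicates fold to the
# constants mapped above, so e.g. torch.compiler.is_compiling() traces to True
# here even though it returns False in eager mode (see
# handle_tracing_state_functions below).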

bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])


class BaseTorchVariable(VariableTracker):
    """common base for all torch.* functions, classes, modules and other things"""

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def reconstruct(self, codegen):
        try:
            name = f"{self.value.__module__}.{self.value.__name__}"
        except Exception:
            name = f"torch_obj_{id(self.value)}"
        unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        codegen.extend_output(
            codegen.setup_globally_cached(unique_var_name, self.value, False)
        )

    def as_proxy(self):
        return self.value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_hasattr(self, tx, name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True
        return getattr(self.value, "__module__", None) == "math"


class TorchCtxManagerClassVariable(BaseTorchVariable):
    """Points to a context manager class in torch.* that dynamo has an implementation for"""

    def __repr__(self):
        return f"TorchCtxManagerClassVariable({self.value})"

    @staticmethod
    def is_matching_cls(value):
        # Unwrap if it's a functools.lru_cache wrapper
        value = unwrap_if_wrapper(value)
        # We can't do isinstance(value, type) check because some ctx managers
        # are implemented as a function decorated by contextlib.contextmanager,
        # E.g., torch._functorch.vmap.vmap_increment_nesting.
        return (
            # Context manager type or function with @contextmanager is callable
            callable(value)
            and (
                hashable(value)  # accesses value.__hash__()
                and value in supported_ctx_manager_classes
            )
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            DisabledSavedTensorsHooksVariable,
            DualLevelContextManager,
            GradIncrementNestingCtxManagerVariable,
            GradInplaceRequiresGradCtxManagerVariable,
            GradModeVariable,
            InferenceModeVariable,
            JvpIncrementNestingCtxManagerVariable,
            SetFwdGradEnabledContextManager,
            StreamVariable,
            VmapIncrementNestingCtxManagerVariable,
        )

        if self.value is torch.no_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, False)
                return ctx.call_function(tx, args, kwargs)
            else:
                return GradModeVariable.create(tx, False)
        elif self.value is torch.enable_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, True)
                return ctx.call_function(tx, args, kwargs)
            return GradModeVariable.create(tx, True)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(
                tx, args[0].as_python_constant(), initialized=True
            )
        elif self.value is torch.inference_mode:
            assert len(args) <= 1 and len(kwargs) == 0
            inf_mode = args[0].as_python_constant() if len(args) == 1 else True
            return InferenceModeVariable.create(tx, inf_mode)
        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls

            return wrap_fx_proxy_cls(
                StreamVariable,
                tx,
                tx.output.create_proxy(
                    "call_function",
                    self.value,
                    (),
                    {},
                ),
            )
        elif self.value in (
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ):
            return AutocastModeVariable.create(self.value, args, kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            warning_once(log, "Profiler function %s will be ignored", self.value)
            return NullContextVariable()
        elif self.value is torch._C.DisableTorchFunctionSubclass:
            assert not (args or kwargs)
            return TorchFunctionDisableVariable.create(tx)
        elif self.value is torch._functorch.vmap.vmap_increment_nesting:
            assert len(args) == 2
            return VmapIncrementNestingCtxManagerVariable.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting:
            assert len(args) == 0
            return JvpIncrementNestingCtxManagerVariable.create(tx)
        elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled:
            assert len(args) == 1
            return SetFwdGradEnabledContextManager.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch.autograd.forward_ad.dual_level:
            assert len(args) == 0
            return DualLevelContextManager.create(tx)
        elif self.value is torch._functorch.eager_transforms.grad_increment_nesting:
            assert len(args) == 0
            return GradIncrementNestingCtxManagerVariable.create(tx)
        elif (
            self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad
        ):
            assert len(args) == 1
            return GradInplaceRequiresGradCtxManagerVariable.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )
        elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
            assert len(args) == 1
            return DisabledSavedTensorsHooksVariable.create(
                tx, args[0].as_python_constant()
            )

        return super().call_function(tx, args, kwargs)


class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in FX graph"""

    def __repr__(self):
        return f"TorchInGraphFunctionVariable({self.value})"

    def get_function(self):
        return self.value

    @staticmethod
    @functools.lru_cache(None)
    def _get_handlers():
        """Build a dict from function -> method to handle it so that we are O(1)
        in terms of the number of functions with special handling."""
        handlers = {}
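
        # Illustrative note: `register` is a local decorator; each
        # @register(fn, ...) below records a mapping from those torch
        # callables to the decorated handler, which call_function() later
        # looks up via self.value.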
        def register(*fns):
            def _register(handler):
                for fn in fns:
                    assert fn not in handlers, fn
                    handlers[fn] = handler
                return handler

            assert callable(fns[0])
            return _register

        from torch.backends.cuda import SDPAParams

        from . import (
            ConstantVariable,
            DeterministicAlgorithmsVariable,
            GradModeVariable,
            StreamContextVariable,
            SymNodeVariable,
            TensorVariable,
            UserDefinedObjectVariable,
        )
        from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls

        @register(*tracing_state_functions)
        def handle_tracing_state_functions(self, tx, *args, **kwargs):
            assert not args and not kwargs
            # See: https://github.com/pytorch/pytorch/issues/110765
            if self.value in (
                torch._utils.is_compiling,
                torch._dynamo.external_utils.is_compiling,
                torch.compiler.is_compiling,
                torch.compiler.is_dynamo_compiling,
            ):
                tx.mark_inconsistent_side_effects()
            return ConstantVariable.create(tracing_state_functions[self.value])

        @register(torch.overrides.get_default_nowrap_functions.__wrapped__)
        def handle_get_default_nowrap_functions(self, tx, *args, **kwargs):
            # [Note: __torch_function__] we return empty here because we restrict
            # the set of functions that we trace __torch_function__ on to
            # functions outside of the actual set. Implementing this properly will
            # require implementing some variable types to track and compare
            # tensor getset descriptors
            return SourcelessBuilder.create(
                tx, torch.overrides.get_default_nowrap_functions()
            )

        @register(torch.ops.inductor.accumulate_grad_.default)
        def handle_accumulate_grad_(self, tx, *args, **kwargs):
            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfill.accumulate_grad), args, kwargs
            )

        @register(math.radians)
        def handle_radians(self, tx, *args, **kwargs):
            if not check_unspec_or_constant_args(args, kwargs):
                # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfill.radians), args, kwargs
                )

        @register(torch.is_tensor, torch.overrides.is_tensor_like)
        def handle_is_tensor(self, tx, arg):
            if isinstance(arg, TensorVariable) or (
                self.value is torch.overrides.is_tensor_like
                and isinstance(arg, UserDefinedObjectVariable)
                and hasattr(arg.value, "__torch_function__")
            ):
                return ConstantVariable.create(True)
            else:
                return ConstantVariable.create(False)

        @register(
            torch.is_floating_point,
            torch.is_complex,
        )
        def handle_is_floating_point(self, tx, input):
            input_arg = input
            if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
                if self.value is torch.is_floating_point:
                    return ConstantVariable.create(input_arg.dtype.is_floating_point)
                elif self.value is torch.is_complex:
                    return ConstantVariable.create(input_arg.dtype.is_complex)
                else:
                    raise AssertionError(f"calling {self.value}")

        @register(torch.numel)
        def handle_numel(self, tx, input):
            if isinstance(input, TensorVariable) and input.size is not None:
                return ConstantVariable.create(product(input.size))
            elif isinstance(input, TensorVariable):
                # Workaround dynamic shapes issue
                return input.call_method(tx, "numel", [], {})

        @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
        def handle_tensor_size_rewrites(self, tx, input):
            assert isinstance(input, TensorVariable)
            return input.call_method(tx, "size", [], {})

        @register(
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        )
        def handle_ntuple(self, tx, *args, **kwargs):
            return self._call_ntuple(tx, args, kwargs)

        @register(torch.is_grad_enabled)
        def handle_is_grad_enabled(self, tx):
            install_guard(GradModeVariable._guards_singleton)
            return ConstantVariable.create(torch.is_grad_enabled())
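            # Illustrative note: folding torch.is_grad_enabled() to a constant
            # is kept sound by the guard installed above -- if the global grad
            # mode differs on a later call, the guard fails and the frame is
            # recompiled.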

        @register(torch.use_deterministic_algorithms)
        def handle_use_deterministic_algorithms(self, tx, mode, warn_only=False):
            if warn_only and warn_only.as_python_constant():
                unimplemented("torch.use_deterministic_algorithms(warn_only=True)")
            return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())

        @register(torch.are_deterministic_algorithms_enabled)
        def handle_are_deterministic_algorithms_enabled(self, tx):
            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
            return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())

        @register(torch._C._is_torch_function_enabled)
        def handle_is_torch_function_enabled(self, tx):
            install_guard(TorchFunctionDisableVariable._guards_singleton)
            return ConstantVariable.create(tx.output.torch_function_enabled)

        @register(
            torch.overrides.has_torch_function,
            torch.overrides.has_torch_function_variadic,
            torch.overrides.has_torch_function_unary,
        )
        def handle_has_torch_function(self, tx, *args):
            elems = (
                args[0].unpack_var_sequence(tx)
                if len(args) == 1 and isinstance(args[0], TupleVariable)
                else args
            )
            return ConstantVariable.create(
                any(has_torch_function(x) for x in elems),
            )

        @register(
            *dict.fromkeys(  # remove duplicates
                device_interface.stream
                for _, device_interface in get_registered_device_interfaces()
            )
        )
        def handle_device_interface_stream(self, tx, stream):
            return StreamContextVariable.create(tx, stream)

        @register(torch.from_numpy)
        def handle_from_numpy(self, tx, *args):
            if not config.trace_numpy:
                unimplemented("torch.from_numpy. config.trace_numpy is False")
            if not np:
                unimplemented("torch.from_numpy. NumPy is not available")
            return wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.as_tensor,
                    *proxy_args_kwargs(args, {}),
                ),
                example_value=None,
            )
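            # Illustrative note: the proxy above targets torch.as_tensor, so a
            # traced torch.from_numpy(arr) appears in the FX graph as
            # torch.as_tensor(arr) rather than as from_numpy itself.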

        @register(torch.jit.annotate)
        def handle_jit_annotate(self, tx, the_type, the_value):
            return the_value

        @register(torch.backends.cudnn.is_acceptable)
        def handle_cudnn_is_acceptable(self, tx, tensor, *extra):
            # is_acceptable(tensor) returns true if
            #   (a) tensor dtype/device are supported by cudnn
            #   (b) cudnn is available
            #   (c) some initialization has completed
            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
            assert not extra, "Expect 1 input to cudnn.is_acceptable"
            assert isinstance(
                tensor, TensorVariable
            ), "Expect input to cudnn.is_acceptable to be a tensor"
            tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device)
            return ConstantVariable.create(
                torch.backends.cudnn.is_acceptable(tensor_inp)
            )

        @register(torch.utils.hooks.BackwardHook)
        def handle_backward_hook(self, tx, *args, **kwargs):
            return variables.BackwardHookVariable.create(tx, *args, **kwargs)

        @register(torch.nn.Parameter)
        def handle_parameter(self, tx, *args, **kwargs):
            return self.call_nn_parameter(tx, *args, **kwargs)

        @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int)
        def handle_sym_size(self_, tx, self, dim=None):
            # we see this when retracing already traced code
            if dim is not None:
                return self.call_method(tx, "size", [dim], {})

        @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int)
        def handle_sym_stride(self_, tx, self, dim=None):
            if dim is not None:
                return self.call_method(tx, "stride", [dim], {})

        @register(torch.addcdiv)
        def handle_addcdiv(self, tx, *args, **kwargs):
            if len(args) == 3 and "value" in kwargs and len(kwargs) == 1:
                # decompose addcdiv into constituent ops, prevents a graph break due to converting
                # value to a scalar
                result = TorchInGraphFunctionVariable(torch.div).call_function(
                    tx, [*args[1:]], {}
                )
                result = TorchInGraphFunctionVariable(torch.mul).call_function(
                    tx, [result, kwargs["value"]], {}
                )
                return TorchInGraphFunctionVariable(torch.add).call_function(
                    tx, [args[0], result], {}
                )
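                # Illustrative note: this mirrors the identity
                # addcdiv(input, t1, t2, value=v) == input + v * (t1 / t2),
                # emitted as separate div, mul and add nodes.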

        @register(torch._assert)
        def handle_assert(self, tx, condition, message):
            if (condition.is_python_constant() and condition.as_python_constant()) or (
                isinstance(condition, variables.SymNodeVariable)
                and condition.evaluate_expr()
            ):
                return ConstantVariable(None)

        @register(SDPAParams)
        def handle_sdpa_params(self, tx, *args, **kwargs):
            return wrap_fx_proxy(
                tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch._C._SDPAParams,
                    *proxy_args_kwargs(args, kwargs),
                ),
                param_vars=args,
            )

        if DistributedVariable.is_available():
            from torch.distributed._tensor import DTensor
            from torch.distributed.distributed_c10d import (
                _get_group_size_by_name,
                _get_group_tag,
                _rank_not_in_group,
                _resolve_group_name_by_ranks_and_tag,
                get_process_group_ranks,
            )

            @register(
                _get_group_size_by_name,
                _get_group_tag,
                _rank_not_in_group,
                get_process_group_ranks,
                _resolve_group_name_by_ranks_and_tag,
            )
            def handle_constant_processgroup_functions(self, tx, *args):
                # because the input is a "ProcessGroupVariable", we'll be guarding on its
                # ID_MATCH based on how it was constructed.

                # We desugar it at trace time into ranks by directly calling the util
                # and baking the result into the trace.
                if len(args) == 1:
                    # group or group name
                    assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable))
                elif len(args) == 2:
                    # ranks + tag
                    assert isinstance(args[0], ListVariable) and isinstance(
                        args[1], ConstantVariable
                    )
                else:
                    raise AssertionError(
                        f"Invalid group value ({args}) for constant pg "
                        f"function {self.value}"
                    )
                args_as_value = [arg.as_python_constant() for arg in args]
                invocation_result = self.value(*args_as_value)

                # Note - while we *could* cook up sources around invocations, like a FunctionSource,
                # the space of invoking functions in the middle of the guard chain is very iffy. As such,
                # guard propagation via options is the best we can do.
                return SourcelessBuilder.create(tx, invocation_result)

            @register(DTensor.from_local)
            def handle_from_local(self, tx, *args, **kwargs):
                # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
                # and rewrite args to have only proxyable args, then insert call_function
                args_as_value = [x.as_python_constant() for x in args[1:]]
                kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

                def fn_with_prim_types(x):
                    return self.value(x, *args_as_value, **kwargs_as_value)

                # attach the same function name for better debugging
                fn_with_prim_types.__name__ = "prim " + self.value.__name__

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        fn_with_prim_types,
                        *proxy_args_kwargs([args[0]], {}),
                    ),
                )

        @register(torch.nested.nested_tensor)
        def handle_nested_tensor(
            self, tx, tensor_list=None, *args, layout=None, **kwargs
        ):
            from .lists import BaseListVariable

            if layout and layout.as_python_constant() == torch.strided:
                unimplemented("torch.compile does not support strided NestedTensor")
            if not isinstance(tensor_list, BaseListVariable):
                unimplemented("nested_tensor with non-list input")

        @register(torch.nn.functional.one_hot)
        def handle_one_hot(self, tx, *args, **kwargs):
            if len(args) + len(kwargs) == 1 or (
                len(args) == 2
                and args[1].is_python_constant()
                and args[1].as_python_constant() == -1
            ):
                unimplemented(
                    "torch.nn.functional.one_hot with data-dependent output shape"
                )

        @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
        def handle_guard_size_oblivious(self, tx, expr):
            if isinstance(expr, SymNodeVariable):
                # TODO: this probably should be folded somewhere else but I'm not sure where
                # TODO: some of the other symbolic_shapes special tools can also get this treatment too
                return variables.ConstantVariable.create(
                    torch.fx.experimental.symbolic_shapes.guard_size_oblivious(
                        expr.sym_num
                    )
                )
            elif isinstance(expr, ConstantVariable):
                return expr

        @register(torch._C._autograd._unsafe_set_version_counter)
        def handle_unsafe_set_version_counter(self, tx, *args, **kwargs):
            from ..tensor_version_op import _unsafe_set_version_counter

            return TorchInGraphFunctionVariable(
                _unsafe_set_version_counter
            ).call_function(tx, [*args], kwargs)

        @register(torch.tensor)
        def handle_torch_tensor(self, tx, *args, **kwargs):
            def check_any_unspec(x):
                # NB: This includes UnspecializedPythonVariable
                if isinstance(x, (TensorVariable, SymNodeVariable)):
                    return True
                elif isinstance(x, (ListVariable, TupleVariable)):
                    return any(check_any_unspec(y) for y in x.items)
                # TODO: there may be other recursive structures you need to
                # check
                else:
                    return False

            data_arg = None
            if args:
                data_arg = args[0]
            elif "data" in kwargs:
                data_arg = kwargs["data"]

            # NB: OK to pass torch.tensor(tensor), this will trace fine
            if not isinstance(data_arg, TensorVariable) and check_any_unspec(data_arg):
                # This is slower and less canonical, so only use it if we
                # have to
                return TorchInGraphFunctionVariable(torch._refs.tensor).call_function(
                    tx, [*args], kwargs
                )

        return handlers

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import ConstantVariable, SymNodeVariable, TensorVariable
        from .builder import wrap_fx_proxy

        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
            # constant fold
            return ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )

        special_handler = self._get_handlers().get(self.value)
        if special_handler:
            result = special_handler(self, tx, *args, **kwargs)
            if result:
                return result

        if can_dispatch_torch_function(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
        else:
            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
            all_ints_or_floats = all(
                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
                for x in args
            )
            if (
                getattr(self.value, "__module__", "") == "torch"
                and self.value.__name__ in bin_ops
                and any_symints_or_symfloats
                and all_ints_or_floats
            ):
                msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
                log.warning(msg)
                unimplemented(msg)

            # TODO(voz): Replace w/ dynamic shape rewrite table.
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
            # of value + args to determine this.
            fn_ = self.value
            if any_symints_or_symfloats:
                torch_sym_op = f"_sym_{self.value.__name__}"
                if getattr(self.value, "__module__", None) == "math" and hasattr(
                    torch, torch_sym_op
                ):
                    fn_ = getattr(torch, torch_sym_op)

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            if (
                isinstance(tensor_variable, TensorVariable)
                and "requires_grad" in kwargs
                and kwargs["requires_grad"].as_python_constant()
            ):
                unimplemented(
                    """factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
                )

            if "out" in kwargs and not (
                isinstance(kwargs["out"], variables.ConstantVariable)
                and kwargs["out"].as_python_constant() is None
            ):
                # out variants of torch operators like torch.sort and
                # torch.sigmoid mutate the tensors in the out field. Track such
                # tensors and rewrite the symbolic locals.
                if isinstance(tensor_variable, TupleVariable):
                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
                    output_tensor_names = [
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
                    ]
                    for idx, name in enumerate(output_tensor_names):
                        if name in tx.symbolic_locals:
                            tx.symbolic_locals[name] = tensor_variable.items[idx]
                    for out_tensor, result_tensor in zip(
                        kwargs["out"].items, tensor_variable.items
                    ):
                        if (
                            out_tensor.source
                            and out_tensor in tx.output.graphargs
                            and isinstance(out_tensor, variables.TensorVariable)
                            and isinstance(result_tensor, variables.TensorVariable)
                            and out_tensor.size != result_tensor.size
                        ):
                            # It's hard to get out variants with resizing on graph inputs
                            # to work properly across dynamo/aot/inductor, so just fall back.
                            unimplemented("out variants with resizing on graph inputs")
                elif isinstance(tensor_variable, TensorVariable):
                    assert isinstance(kwargs["out"], TensorVariable)
                    assert "example_value" in kwargs["out"].proxy.node.meta
                    fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
                    if (
                        kwargs["out"].source
                        and kwargs["out"] in tx.output.graphargs
                        and fake_out.shape != fake_tensor.shape
                    ):
                        # It's hard to get out variants with resizing on graph inputs
                        # to work properly across dynamo/aot/inductor, so just fall back.
                        unimplemented("out variants with resizing on graph inputs")
                    if not torch._prims_common.is_contiguous(fake_out):
                        # It's difficult to handle strides correctly in functionalization
                        # when calling an out= op with a non-contiguous out argument
                        unimplemented(
                            "out= op was called where output tensor was non-contiguous"
                        )
                    name = tx.find_symbolic_locals_name(kwargs["out"])
                    if name in tx.symbolic_locals:
                        tx.symbolic_locals[name] = tensor_variable
                else:
                    unimplemented(f"out variant of {type(kwargs['out'])}")

            return tensor_variable

    def _call_ntuple(self, tx, args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)
        assert not kwargs

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable.create(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple)
        else:
            return handle_ntuple(args[0])
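        # Illustrative note: for example `_pair` (whose closure cell holds
        # count=2) maps a constant 3 to (3, 3), while an unpackable input like
        # [1, 2] becomes a TupleVariable of its elements, matching eager
        # _ntuple behavior.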

    @classmethod
    def call_nn_parameter(cls, tx, data=None, requires_grad=True):
        """A call to torch.nn.Parameter() gets lifted to before the graph"""
        if isinstance(requires_grad, variables.VariableTracker):
            try:
                requires_grad = requires_grad.as_python_constant()
            except NotImplementedError:
                unimplemented("Parameter(requires_grad=...) not constant")

        if not isinstance(data, variables.TensorVariable):
            unimplemented(f"Parameter(data={data}) not implemented")

        # this results in cleaner graphs, but only works for inputs
        if data.source:
            return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

        try:
            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
            dtype = data.var_getattr(tx, "dtype").as_python_constant()
            device = data.var_getattr(tx, "device").as_python_constant()
        except NotImplementedError as e:
            unimplemented(f"Parameter not python_constant: {e}")

        placeholder = tx.output.synthetic_graph_input(
            new_parameter_placeholder, [shape, dtype, device, requires_grad]
        )
        if data.requires_grad:
            data = data.call_method(tx, "detach", [], {})

        from .builder import wrap_fx_proxy

        result = wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function",
                tracable_create_parameter,
                (data.as_proxy(), placeholder.as_proxy()),
                {},
            ),
        )
        assert isinstance(result, variables.TensorVariable)
        result.class_type = torch.nn.Parameter

        # TODO(jansel/bdhirsh) - There is some issue with
        # tracable_create_parameter. It does not seem to use the right
        # grad_enabled. Since this is a parameter, we can just override the
        # has_grad_fn field to False to work around the issue.
        result.has_grad_fn = False

        # In reconstruct(), we should use the original parameter. The one
        # returned by the graph will be an alias.
        result.source = placeholder.source

        # TODO(jansel): if the new param falls out of scope, currently it won't get freed until
        # the end of the graph. We should fix this.
        return result

    @staticmethod
    def _nn_param_via_prefix_insert(tx, data, requires_grad):
        # Alternate version if we have a .source
        from .builder import VariableBuilder

        varname = tx.output.new_var()

        # construct the nn.Parameter before the graph and save it to varname
        cg = PyCodegen(tx)
        cg.load_import_from("torch.nn", "Parameter")
        cg(data.source)
        cg(variables.ConstantVariable(requires_grad))
        cg.call_function(2, True)
        cg.store(varname)
        tx.output.pregraph_bytecode.extend(cg.get_instructions())

        data_node = data.as_proxy().node
        if data_node.op not in ("placeholder", "get_attr"):
            unimplemented(
                "Unexpected type of data placeholder op for parameter construction"
            )

        # add the newly constructed nn.Parameter as a graph input
        source = SyntheticLocalSource(varname)
        example_value = torch.nn.Parameter(
            tx.output.example_value_from_input_node(data.as_proxy().node)
        )
        result = VariableBuilder(tx, source)(example_value)
        # No need to guard on this since we already guarded on `data`.
        # These guards would fail since varname doesn't exist until after the function starts.
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result