fake_impls.py

# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging
from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label

pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []

aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)
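
# NOTE: dict preserves insertion order (Python 3.7+), so the mapping returned by
# ordered_set behaves as an insertion-ordered set with O(1) membership tests,
# e.g. `aten.empty_like.default in _like_tensor_constructors` below.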


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    if device.type == "hpu":
        return False
    return True


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)

_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
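
# register_op_impl accepts three kinds of keys, as implemented above:
#   1. a single OpOverload          -> stored in op_implementations_dict
#   2. a list/tuple of OpOverloads  -> each element registered individually
#   3. a predicate on OpOverloads   -> appended to op_implementations_checks
# An illustrative (hypothetical) registration of a fake impl for one overload:
#
#   @register_op_impl(aten.my_op.default)          # hypothetical overload
#   def my_op_fake_impl(fake_mode, func, *args, **kwargs):
#       ...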


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor data-dependent in only some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


def _unique(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Do not use a memo for unique_dim
    if dim is not None or (nnz := arg.unique_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        if dim is None:
            arg.unique_memo = nnz

    if dim is None:
        ret = [arg.new_empty((nnz,))]
    else:
        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]

    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
    if return_inverse or return_if_dim_and_cpu:
        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
    else:
        inverse = arg.new_empty(0)
    ret.append(inverse)

    if return_counts or return_if_dim_and_cpu:
        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
    else:
        counts = arg.new_empty(0)
    ret.append(counts)

    return tuple(ret)


@register_op_impl(aten._unique2.default)
def unique2(
    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)


@register_op_impl(aten.unique_dim.default)
def unique_dim(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else dim % max(arg.ndim, 1),
        sorted,
        return_inverse,
        return_counts,
    )


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
    # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if (r := arg.item_memo) is not None:
        return r
    if fake_mode.shape_env is None or (
        not fake_mode.shape_env.allow_scalar_outputs
        and not fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
    arg.item_memo = r
    return r
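
# For context: Tensor.item() on a FakeTensor dispatches to
# aten._local_scalar_dense, so under a ShapeEnv that allows scalar outputs the
# impl above yields an unbacked SymInt/SymFloat/SymBool (memoized on the
# tensor via item_memo) instead of raising DataDependentOutputException.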


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz

    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )

    if not has_free_symbols(self.numel()):
        if self.numel() > 2:
            maxval = int(self.numel())

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to default device handling,
# since the op can take in non-zero sized cpu
# index tensors with cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the unsqueezing of the input is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset


@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale
    # unused: seqused_k, alibi_slopes, window_size_left, window_size_right

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    max_seqlen_k = kwargs["max_seqlen_k"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k


@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)

    if not batch_first:
        # Inputs should have shape (batch_size, seq_len, *)
        inputs = inputs.transpose(0, 1)

    res_size = inputs.shape[1:]
    packed_data = inputs.new_empty(res_size)
    batch_size = inputs.new_empty((new_batch_size,))
    return (packed_data, batch_size)


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i})",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
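
# Worked example of the broadcasting rule above: aligning shapes from the
# right, infer_size((2, 1, 4), (3, 1)) compares 4 vs 1 -> 4, 1 vs 3 -> 3, and
# 2 vs (missing, treated as 1) -> 2, giving (2, 3, 4).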


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and op.shape == final_shape
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors on same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    return FAST_OP_IMPLEMENTATIONS
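
# Illustrative usage (a sketch, not part of this module's API surface): the
# table returned by get_fast_op_impls() maps OpOverloads to the fast binary
# impls built above, each taking (fake_mode, *args), e.g.:
#
#   fast_impls = get_fast_op_impls()
#   fast_add = fast_impls[torch.ops.aten.add.Tensor]
#   # fast_add(fake_mode, fake_x, fake_y) -> FakeTensor with the broadcast shape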