# mypy: allow-untyped-defs
import functools
import math
import operator

import torch
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention

from .nested_tensor import NestedTensor

from typing import *  # noqa: F403

import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


# Simplifying assumption: we assume that the batch dim is always the left-most
# dim, and the ragged dim is always the second dim.
def _outer_to_inner_dim(ndim, dim):
    assert dim >= 0 and dim < ndim
    return 0 if dim < 2 else dim - 1
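
# Illustrative mapping (not part of the module): for an NT of outer shape
# (B, j0, D), ndim=3, the batch and ragged dims both collapse onto dim 0 of
# the underlying values tensor:
#   _outer_to_inner_dim(3, 0) == 0  # batch dim
#   _outer_to_inner_dim(3, 1) == 0  # ragged dim
#   _outer_to_inner_dim(3, 2) == 1  # trailing dense dim D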


def _wrap_jagged_dim(
    ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False
):
    from torch._prims_common import canonicalize_dims

    wrapped = canonicalize_dims(ndim, dim)
    if wrapped == 1:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1")
    elif wrapped == 0 and not allow_batch_dim:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
    return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped


def _wrap_jagged_dims(ndim, dims, op_name):
    # ex: (2, 3, 4) -> (1, 2, 3)
    # ex: (0, 1, 4) -> (0, 3)
    from torch._prims_common import canonicalize_dims

    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
    # This logic needs to be done after we canonicalize dims but before we
    # map to inner dims so we can print a nicer error message.
    zero_in_dims = 0 in wrapped_dims
    one_in_dims = 1 in wrapped_dims
    if zero_in_dims ^ one_in_dims:
        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
        raise RuntimeError(
            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
            " dimension is not supported for NestedTensor"
        )
    return (
        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
        zero_in_dims,
    )


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
    min_args = len(named_arg_types) - num_optional_args

    # special case: ellipses allows for any number of unchecked args at the end
    if named_arg_types[-1] == "...":
        named_arg_types = named_arg_types[:-1]
    else:
        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
                f"arguments and at most {len(named_arg_types)} arguments, but got: "
                f"{len(args)} arguments"
            )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor)
        and x._lengths is None
        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
        "jt_all": lambda x: isinstance(
            x, NestedTensor
        ),  # ops with "jt_all" can accept all kinds of JT
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        _check_fn = arg_type_check_fns[normalized_arg_type]

        def check_fn(x, is_optional=is_optional):
            if is_optional:
                return x is None or _check_fn(x)
            else:
                return _check_fn(x)

        if not check_fn(args[i]):
            type_to_desc = {
                "t": "tensor",
                "t?": "optional tensor",
                "jt": "contiguous jagged layout NestedTensor",
                "jt_all": "jagged layout NestedTensor",
                "any": "<any type>",
            }

            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
                f"{type_to_desc[arg_type]}"
            )
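
# Illustrative call (not part of the module): a schema string pairs argument
# names with type tags ("t", "jt", "jt_all", "any"); "?" marks an argument as
# optional and a trailing "..." permits extra unchecked args. For example,
#   check_schema("self: jt, dim: any, keepdim: any?", func, nt, 0)
# raises ValueError unless `nt` is a contiguous jagged-layout NestedTensor.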


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Calling into .shape here
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


# returns True if the raggedness-relevant portions of the NT shape
# match those of the specified size
def raggedness_matches(nt, size):
    end = nt._ragged_idx + 1
    nt_ragged = nt._size[:end]
    size_ragged = size[:end]
    return len(nt_ragged) == len(size_ragged) and (
        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
    )
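
# Illustrative (not part of the module): with nt of shape (3, j0, 5) and
# _ragged_idx == 1, only the leading (batch, ragged) prefix is compared and
# -1 acts as a wildcard:
#   raggedness_matches(nt, (3, nt.shape[1], 8))  # True
#   raggedness_matches(nt, (3, -1, 8))           # True
#   raggedness_matches(nt, (4, -1, 5))           # False: batch size differs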


def squeeze_leading_ones(t):
    # Note: [ Squeezing leading ones ]
    #
    # Squeeze leading ones from t.
    #
    # We want:
    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
    #
    # 1) Squeeze extra ones and grab values from NT
    #   (1, 1, ?, ?) -> (?, ?)  and  (sum(*), ?, ?) -> (B, j0, ?, ?)
    # 2) Do dense broadcasting:
    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
    # 3) Construct nested tensor
    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
    #
    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
    # at step (4) and we would need to update this function to record how
    # many ones we unsqueezed.
    # NB: guard on dim() so an all-ones shape squeezes down to a 0-dim tensor
    # instead of raising IndexError on t.shape[0].
    while t.dim() > 0 and t.shape[0] == 1:
        t = t.squeeze(0)
    return t
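
# Illustrative (not part of the module):
#   >>> squeeze_leading_ones(torch.ones(1, 1, 3, 4)).shape
#   torch.Size([3, 4])
#   >>> squeeze_leading_ones(torch.ones(2, 1, 3)).shape  # only *leading* ones
#   torch.Size([2, 1, 3])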


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
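
# Illustrative registration (hypothetical handler, not part of the module):
#
#   @register_jagged_func(torch.ops.aten.clone.default, "self: jt_all, ...")
#   def clone_default(func, *args, **kwargs):
#       ...
#
# The decorator stores a wrapper in JAGGED_OPS_TABLE that first validates the
# args against the schema string, then calls clone_default(aten_op, *args,
# **kwargs) with the aten overload as the first argument.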


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
    if dispatch_func is not None:
        return dispatch_func

    # Handle pointwise fallbacks
    if torch.Tag.pointwise in func.tags:
        # Assume there aren't additional tensors that aren't the "unary/binary" args
        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
        if num_tensor_args == 1:
            check_schema("self: jt_all, ...", func, *args, **kwargs)
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
        "_metadata_cache": arg._metadata_cache,
        "_ragged_idx": arg._ragged_idx,
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(
        func(args[0]._values, *args[1:], **kwargs), **extract_kwargs(args[0])
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )
    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
    # either a is NT or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    # === Handle broadcasting across the batch / ragged dims ===

    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Harder case: do manual broadcasting over unbound components
    # when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
        # be (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        # need to use offsets to broadcast across ragged dim properly
        # NB: inefficient fallback here; Triton codegen can help this
        # TODO: Make this work with autograd
        outputs = []
        for a_comp, b_comp in zip(a.unbind(), b.unbind()):
            outputs.append(func(a_comp, b_comp, *args[2:], **kwargs))
        new_values = torch.cat(outputs, dim=0)
        return NestedTensor(new_values, **extracted_kwargs)

    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the
    # invariant that the ragged dim is wrt the left-most batch dim
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
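
# Illustrative broadcasting behavior (not part of the module), for an NT `nt`
# of shape (B, j0, D) and dense tensors:
#   nt + torch.randn(D)        -> (B, j0, D)  # dense broadcast on values
#   nt + torch.randn(1, 1, D)  -> (B, j0, D)  # leading ones squeezed first
#   nt * nt                    -> (B, j0, D)  # raggedness must match exactly
#   nt + torch.randn(B, 1, D)  -> (B, j0, D)  # slow per-component fallback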


def jagged_torch_function(func, *args, **kwargs):
    # SDPA has special kernels that handle nested tensors.
    # Dispatch to the correct implementation here
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    # Handle flatten() here because it's CompositeImplicit.
    if func.__name__ == "flatten":

        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # NB: stay in outer dim space because we're going to redispatch on a NT input
        start_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
        )

        if start_dim == end_dim:
            return inp

        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
        return inp.reshape(*new_shape)

    raise NotImplementedError(func)
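
# Illustrative (not part of the module): flatten() stays in outer dim space
# and redispatches through reshape on the NT itself:
#   nt of shape (B, j0, 5, 2) -> nt.flatten(2, 3) has shape (B, j0, 10)
# Flattening across the batch or ragged dim (start_dim of 0 or 1) raises.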


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.numel.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt_all",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
        if args[0]._lengths is not None:
            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return args[0]._values.storage_offset()


@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
def prim_layout_default(func, *args, **kwargs):
    return torch.jagged


@register_jagged_func(
    [torch.ops.aten.size.default],
    "self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )


@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
def is_contiguous_general(func, *args, **kwargs):
    from torch._prims_common import is_contiguous_for_memory_format

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # If created from narrow() check for lengths
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )
    if new_kwargs["memory_format"] == torch.preserve_format:
        return True
    return is_contiguous_for_memory_format(inp._values, **new_kwargs)


register_jagged_func(
    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    ds = NestedTensor(
        torch.mm(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.mm(grad_output._values.T, inp._values)
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)


@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
def to_copy_default(func, *args, **kwargs):
    from .nested_tensor import _tensor_symint_registry

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    new_offsets = inp._offsets.to(device=new_values.device)
    _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    inp_kwargs = extract_kwargs(inp)
    inp_kwargs["offsets"] = new_offsets

    return NestedTensor(new_values, **inp_kwargs)


register_jagged_func(
    [
        torch.ops.aten.empty_like.default,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
        torch.ops.aten.detach.default,
    ],
    "self: jt_all",
)(jagged_unary_pointwise)

register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
)(jagged_unary_pointwise)


@register_jagged_func(
    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
)
def native_dropout_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    out1, out2 = func(inp._values, **new_kwargs)
    return (
        NestedTensor(out1, **extract_kwargs(inp)),
        NestedTensor(out2, **extract_kwargs(inp)),
    )


@register_jagged_func(
    torch.ops.aten.native_dropout_backward.default,
    "grad_output: jt, mask: jt, scale: any",
)
def native_dropout_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_output = new_kwargs.pop("grad_output")
    mask = new_kwargs.pop("mask")
    return NestedTensor(
        func(grad_output._values, mask._values, **new_kwargs),
        **extract_kwargs(grad_output),
    )


@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # TODO: Figure out how to handle this better
    # keepdim=True is required to keep the result in jagged format
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
)
def split_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split")

    return tuple(
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    )


@register_jagged_func(
    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
)
def split_with_sizes_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "split_with_sizes"
    )

    return [
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    ]


@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
def chunk_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True
    )

    if new_kwargs["dim"] == 0:
        chunks = new_kwargs["chunks"]

        # get _offsets of the chunks
        lengths = inp._offsets.diff()
        chunked_lengths = lengths.chunk(chunks)
        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]
        nested_kwargs = [
            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
            for per_offsets in chunked_offsets
        ]

        # get _values of the chunks
        split_sizes = [x.sum().item() for x in chunked_lengths]
        chunk_values = inp._values.split(split_sizes)

        # NB: chunk() may return fewer than `chunks` pieces, so iterate over the
        # actual number of chunks rather than the per-chunk size.
        return [
            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
            for i in range(len(chunk_values))
        ]
    else:
        return [
            NestedTensor(values=x, **extract_kwargs(inp))
            for x in func(inp._values, **new_kwargs)
        ]
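
# Illustrative (hypothetical values, not part of the module): chunking on the
# batch dim splits both offsets and values. For offsets [0, 2, 5, 6] and
# chunks=2:
#   lengths = [2, 3, 1] -> chunked_lengths = ([2, 3], [1])
#   -> per-chunk offsets [0, 2, 5] and [0, 1]; values split into 5 and 1 rows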


@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
    # Note that this specializes on the length of the offsets
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

    inp = new_kwargs.pop("input")
    values = inp.values()
    offsets = inp.offsets()
    lengths = inp.lengths()
    ragged_idx = inp._ragged_idx

    if lengths is None:
        return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1))

    if ragged_idx <= 0:
        raise RuntimeError(
            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
        )
    for i in range(lengths.shape[0]):
        if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]:
            raise RuntimeError(
                "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension"
            )
    return [
        torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i])
        for i in range(lengths.shape[0])
    ]
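
# Illustrative (not part of the module): for offsets [0, 2, 5] and values of
# shape (5, D), unbind() returns dense views of shapes (2, D) and (3, D); with
# explicit lengths [1, 2] (i.e. a non-contiguous NT with holes), narrow()
# instead yields a (1, D) slice starting at row 0 and a (2, D) slice at row 2.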


@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    tensors = new_kwargs.pop("tensors")

    # Convert any non-nested to nested
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if inp.is_nested and not other.is_nested:
        return NestedTensor(
            func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
        )
    elif inp.is_nested and other.is_nested:
        # BMM with equivalent ragged dims between the two inputs
        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))

    raise RuntimeError(
        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
    )
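
# Illustrative (not part of the module):
#   (B, j0, D) @ (D, E)           -> (B, j0, E)     # NT x dense
#   (B, j0, D, E) @ (B, j0, E, F) -> (B, j0, D, F)  # NT x NT, matching ragged
# Anything else (e.g. NT x NT with mismatched offsets) raises RuntimeError.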


@register_jagged_func(
    torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
    if not raggedness_matches(inp, size):
        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")

    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
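
# Illustrative (not part of the module): only the non-batch, non-ragged dims
# can actually expand; the leading (B, j0) prefix must match (or use -1):
#   nt of shape (B, j0, 1) -> nt.expand(B, -1, 5) has shape (B, j0, 5)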


@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition._size == other._size == inp._size

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)


@register_jagged_func(
    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
)
def is_same_size_default(func, *args, **kwargs):
    return args[0]._size == args[1]._size


@register_jagged_func(
    torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
)
def sum_dim_IntList(func, *args, **kwargs):
    # sum_dim_IntList can produce a NT or a T depending on whether the ragged
    # dims are reduced away.
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    assert inp._ragged_idx == 1
    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
        inp.dim(), new_kwargs["dim"], "sum"
    )

    if not ragged_reduced_away:
        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
    else:
        # Don't wrap because we reduced away the raggedness
        out = func(inp._values, **new_kwargs)
        if new_kwargs["keepdim"]:
            out = out.unsqueeze(0)
        return out
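
# Illustrative (not part of the module), for nt of shape (B, j0, D):
#   nt.sum(dim=(2,))                 -> NT of shape (B, j0)    # ragged kept
#   nt.sum(dim=(0, 1))               -> dense tensor of shape (D,)
#   nt.sum(dim=(0, 1), keepdim=True) -> dense tensor of shape (1, 1, D)
# Reducing over the ragged dim without the batch dim (or vice versa) raises.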


@register_jagged_func(
    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
)
def transpose_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    from torch._prims_common import canonicalize_dims

    inp = new_kwargs.pop("input")
    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))

    if inp._lengths is not None:
        raise ValueError(
            "transpose(): not supported on jagged layout nested tensor with holes"
        )

    # To support the SDPA API, inputs need to have the ragged idx transposed to
    # dim 2 instead of 1, although the internal Flash and memory-efficient
    # attention implementations will use the inputs with raggedness in dim 1.
    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
        if dim0 == 0 or dim1 == 0:
            raise ValueError(
                "Transpose is not supported on the batch dimension for jagged NT"
            )
        if dim0 == inp._ragged_idx:
            to_dim = dim1
        else:
            to_dim = dim0
        inp_kwargs = extract_kwargs(inp)
        inp_kwargs["_ragged_idx"] = to_dim
        return NestedTensor(
            inp.values().transpose(
                _outer_to_inner_dim(len(inp._size), dim0),
                _outer_to_inner_dim(len(inp._size), dim1),
            ),
            **inp_kwargs,
        )

    new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
    new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
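
# Illustrative (not part of the module): for nt of shape (B, j0, D) with
# _ragged_idx == 1, nt.transpose(1, 2) returns an NT of shape (B, D, j0) whose
# _ragged_idx is 2; the underlying values are transposed from (sum(j0), D) to
# (D, sum(j0)).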


@register_jagged_func(
    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
    "self: jt_all, size: any",
)
def view_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs.pop("size")

    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
        raise RuntimeError(
            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
            f"inp._size is ({inp._size}) and size is ({size})."
        )

    # Ensure specified size still includes batch and ragged dims
    if len(size) < 3 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")

    # outer size: the size of the NT, e.g. [3, j0, 10]
    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
    # this function gets inner_size[inner_idx] for a given inner_idx.
    #
    # example: for outer size [a, b, c, j0, d, e, f]
    # assume that j0 is ragged, the others are concrete integers,
    # and ragged_idx=3
    # inner size will be [b, c, inp._values.size(ragged_idx - 1), d, e, f]
    # therefore:
    #     inner_size[0] = outer_size[1]
    #     inner_size[1] = outer_size[2]
    #     inner_size[2] = inp._values.size(ragged_idx - 1)
    #     inner_size[3] = outer_size[4]
    #     inner_size[4] = outer_size[5]
    def get_inner_size(inner_idx):
        nonlocal inp, size
        if inner_idx == inp._ragged_idx - 1:
            return inp._values.size(inner_idx)
        else:
            return size[inner_idx + 1]

    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]

    return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.native_layer_norm.default,
    "input: jt, normalized_shape: any, weight: any?, bias: any?, eps: any",
)
def native_layer_norm_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    normalized_shape = new_kwargs["normalized_shape"]

    # Ensure we're not trying to normalize over the ragged dim
    if inp.dim() < 3 or (inp.dim() - len(normalized_shape)) < 2:
        raise RuntimeError(
            "layer_norm(): normalizing over ragged dim not supported for nested tensors"
        )

    output, mean, std = func(inp._values, **new_kwargs)
    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)


@register_jagged_func(
    torch.ops.aten.native_layer_norm_backward.default,
    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
)
def native_layer_norm_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_out")
    inp = new_kwargs.pop("input")
    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
    if d_input is None:
        return (None, d_gamma, d_beta)

    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)


@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
def select_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "select")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.slice.Tensor,
    "self: jt, dim: any?, start: any?, end: any?, step: any?",
)
def slice_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.convolution.default,
    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
    "dilation: any, transposed: any, output_padding: any, groups: any",
)
def convolution_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.mean.dim, "self: jt, dim: any?, keepdim: any, dtype: any?"
)
def mean_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # NB: mean expects dim as a single item list of ints for some reason
    new_kwargs["dim"] = [_wrap_jagged_dim(inp.dim(), new_kwargs["dim"][0], "mean")]

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
def stack_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    tensors = new_kwargs.pop("tensors")
    for t in tensors:
        if not isinstance(t, NestedTensor):
            raise RuntimeError("stack(): expected all inputs to be nested tensors")
        if t.dim() != tensors[0].dim():
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same dim"
            )
        if not raggedness_matches(t, tensors[0].shape):
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same nested structure"
            )

    new_kwargs["dim"] = _wrap_jagged_dim(
        tensors[0].dim() + 1, new_kwargs["dim"], "stack"
    )

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(
    torch.ops.aten.embedding.default,
    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
)
def embedding_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    indices = new_kwargs.pop("indices")
    weight = new_kwargs.pop("weight")

    return NestedTensor(
        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
    )


@register_jagged_func(
    [
        torch.ops.aten.values.default,
        torch.ops.aten._nested_get_values.default,
    ],
    "self: jt_all",
)
def values_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # TODO: Handle inference mode properly.
    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
    return inp._values.detach()


@register_jagged_func(
    torch.ops.aten._nested_view_from_jagged.default,
    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?",
)
def _nested_view_from_jagged_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values, offsets, lengths = (
        new_kwargs["input"],
        new_kwargs["offsets"],
        new_kwargs["lengths"],
    )
    ragged_idx = new_kwargs["ragged_idx"]

    return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx)


@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
def _nested_get_offsets(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._offsets


@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
def _nested_get_lengths(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._lengths


@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
def _nested_get_ragged_idx(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._ragged_idx


# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):
    from torch.nested._internal.nested_tensor import _nt_view_dummy

    return _nt_view_dummy()


with torch.library._scoped_library("aten", "IMPL") as aten:
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")