__init__.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import itertools
  4. import operator
  5. import weakref
  6. from enum import Enum
  7. from functools import partial, reduce
  8. from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
  9. import torch
  10. import torch._prims_common as utils
  11. import torch.library
  12. from torch import sym_float, Tensor, TypedStorage
  13. from torch._C import _get_default_device
  14. from torch._library.utils import is_functional_schema
  15. from torch._prims.debug_prims import register_debug_prims
  16. from torch._prims.rng_prims import register_rng_prims
  17. from torch._prims_common import (
  18. Dim,
  19. DimsSequenceType,
  20. DimsType,
  21. IntLike,
  22. Number,
  23. NumberType,
  24. RETURN_TYPE,
  25. ShapeType,
  26. StrideType,
  27. TensorLike,
  28. TensorLikeType,
  29. type_to_dtype,
  30. )
  31. from torch._prims_common.wrappers import backwards_not_supported
  32. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  33. from torch.overrides import handle_torch_function, has_torch_function
  34. from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
# Library handles for the "prims" operator namespace.
# - ``prim`` ("DEF") is used by ``_make_prim`` to define op schemas.
# - The "IMPL" handles register kernels for one dispatch key each:
#   CompositeExplicitAutograd (eager impl), BackendSelect (device-driven
#   routing for ops without tensor inputs), Autograd (backward-unsupported
#   wrapper) and Meta (shape/dtype-only kernel).
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
# Experimental module containing prototype "primitive" operations.

# Public names of this module, i.e. what a star-import of it exports, grouped
# by category. Entries that are commented out are intentionally not exported.
__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "erfcx",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "ndtri",
    "neg",
    "real",
    "reciprocal",
    "round",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "frexp",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "rsqrt",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    "view_element_type",
    #
    # Functionalized view mutations
    #
    "as_strided_scatter",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "clone",
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "copy_strided",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "xor_sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "empty_permuted",
    "scalar_tensor",
    "iota",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "_uniform_helper",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
    #
    # prims for making/sinking tokens
    #
    "_make_token",
    "_sink_tokens",
]
  218. def TensorMeta(
  219. tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
  220. *,
  221. shape: Optional[ShapeType] = None,
  222. strides: Optional[StrideType] = None,
  223. dtype: Optional[torch.dtype] = None,
  224. device: Optional[Union[torch.device, str]] = None,
  225. ):
  226. if isinstance(tensorlike, Number):
  227. assert not shape and (shape is None or isinstance(shape, Sequence))
  228. assert not strides and (strides is None or isinstance(strides, Sequence))
  229. inferred_shape: Tuple[int, ...] = ()
  230. inferred_strides: Tuple[int, ...] = ()
  231. inferred_dtype = type_to_dtype(type(tensorlike))
  232. inferred_device = torch.device("cpu")
  233. # TODO: This looks wrong, a number that is wrapped into a tensor
  234. # needs to behave differently than a scalar tensor for type
  235. # promotion purposes
  236. elif tensorlike is not None:
  237. assert isinstance(tensorlike, torch.Tensor)
  238. inferred_shape = tuple(tensorlike.shape)
  239. inferred_strides = tuple(tensorlike.stride())
  240. inferred_dtype = tensorlike.dtype
  241. inferred_device = tensorlike.device
  242. else:
  243. # If no tensorlike "example" is given then all metadata
  244. # must be provided explicitly
  245. assert shape is not None
  246. assert strides is not None
  247. assert dtype is not None
  248. assert device is not None
  249. shape = inferred_shape if shape is None else tuple(shape) # type: ignore[possibly-undefined]
  250. strides = inferred_strides if strides is None else tuple(strides) # type: ignore[possibly-undefined]
  251. dtype = inferred_dtype if dtype is None else dtype # type: ignore[possibly-undefined]
  252. device = inferred_device if device is None else device # type: ignore[possibly-undefined]
  253. if isinstance(device, str):
  254. device = torch.device(device)
  255. return torch.empty_strided(shape, strides, dtype=dtype, device=device)
def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
    tags: Optional[Sequence[torch.Tag]] = None,
    use_old_custom_ops_api: bool = False,
):
    """
    Creates a primitive operation.

    Registers ``prims::<name>`` with the dispatcher and returns its ``.default``
    overload (``torch.ops.prims.<name>.default``).

    Args:
        schema: full op schema string, e.g. ``"name(Tensor self) -> Tensor"``;
            the op name is everything before the first ``(``.
        return_type: stored as ``return_type`` on both the packet and the overload.
        meta: shape/dtype-only function. It is registered as the Meta kernel or
            fake impl, and it is also run before ``impl_aten`` in the eager kernel.
        impl_aten: eager implementation called by the eager kernel.
        doc: assigned to ``__doc__`` of the packet and the overload.
        tags: if truthy, replaces the overload's ``_tags`` after registration.
        use_old_custom_ops_api: force the ``Library.define``/``impl`` path even
            for functional schemas.
    """

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.) Because we
    # don't have derivative formulas, we must setup a custom autograd function
    # that raises an error if backwards is invoked
    def _autograd_impl(*args, **kwargs):
        # ``_prim`` is looked up when this runs, not when it is defined; it is
        # assigned further down, after registration.
        return backwards_not_supported(_prim)(*args, **kwargs)

    def _backend_select_impl(*args, **kwargs):
        # Route to ``meta`` when the target device (keyword or positional
        # torch.device) is "meta"; otherwise run the eager implementation.
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        if any(isinstance(x, torch.device) and x.type == "meta" for x in args):
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    # Split "name(args) -> ret" into the op name and the "(args) -> ret" tail.
    name = schema.split("(")[0]
    schema = schema[len(name) :]

    # register non-functional ops with old custom ops API
    cpp_schema = torch._C.parse_schema(name + schema)
    if use_old_custom_ops_api or not is_functional_schema(cpp_schema):
        prim.define(name + schema, tags=torch.Tag.pt2_compliant_tag)
        prim_impl.impl(name, _prim_impl)
        prim_autograd_impl.impl(name, _autograd_impl)
        prim_meta_impl.impl(name, meta)
    else:
        # Collect arguments the schema marks as written-to (alias_info.is_write).
        mutates_args = []
        for arg in cpp_schema.arguments:
            if arg.alias_info is not None and arg.alias_info.is_write:
                mutates_args.append(arg.name)
        prim_def = torch.library.custom_op(
            "prims::" + name,
            _prim_impl,
            mutates_args=tuple(mutates_args),
            schema=schema,
        )
        prim_def.register_fake(meta)

    _prim_packet = getattr(torch._ops.ops.prims, name)
    _prim = _prim_packet.default
    if tags:
        _prim._tags = tags

    from torch._subclasses.fake_tensor import contains_tensor_types

    # Ops with no tensor-typed arguments (plus device_put) get a BackendSelect
    # kernel, so that dispatch can be decided by the device argument.
    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
        _prim
    ) in [
        # See https://github.com/pytorch/pytorch/issues/103532
        "prims.device_put.default"
    ]:
        prim_backend_select_impl.impl(name, _backend_select_impl)

    # Attach metadata to both the OpOverloadPacket and the default overload.
    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta
        p.impl_aten = impl_aten

    return _prim
  329. class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
  330. DEFAULT = (0,)
  331. INT_TO_FLOAT = (2,)
  332. ALWAYS_BOOL = (3,)
  333. COMPLEX_TO_FLOAT = (4,)
  334. # TODO: implement dtype validation here, too, or on the corresponding refs
  335. def _prim_elementwise_meta(
  336. *args,
  337. type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
  338. args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
  339. ) -> FakeTensor:
  340. """
  341. Meta function for elementwise operations that produce outputs in the same dtype
  342. as their inputs.
  343. Stride logic is currently incorrect.
  344. """
  345. assert len(args) > 0
  346. utils.check_same_dtype(*args)
  347. args_ = list(args)
  348. if args_with_fixed_dtypes is not None:
  349. args_ = list(args_with_fixed_dtypes) + args_
  350. utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  351. utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
  352. l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
  353. shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
  354. # Acquires the dtype
  355. dtype = None
  356. scalar_type = None
  357. for arg in args:
  358. if isinstance(arg, TensorLike):
  359. if not utils.is_cpu_scalar_tensor(arg):
  360. dtype = arg.dtype
  361. break
  362. else:
  363. dtype = arg.dtype
  364. elif isinstance(arg, Number):
  365. scalar_type = type(arg)
  366. if dtype is None and scalar_type is not None:
  367. dtype = utils.type_to_dtype(scalar_type)
  368. # Acquires the device (if it exists) or number
  369. device = None
  370. number = None
  371. for arg in args_:
  372. if isinstance(arg, TensorLike):
  373. if utils.is_cpu_scalar_tensor(arg):
  374. if device is None:
  375. device = arg.device
  376. # keep going, in case there is a cuda tensor later
  377. else:
  378. device = arg.device
  379. break
  380. elif isinstance(arg, Number):
  381. if number is None:
  382. number = arg
  383. # NOTE: type promotion behavior here is mostly hidden from tests because
  384. # references will typically handle the type promotion properly even if this doesn't
  385. # (but getting it wrong will cause too many casts to be inserted in traces!)
  386. if device is not None:
  387. assert dtype is not None
  388. if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
  389. dtype = dtype
  390. elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
  391. dtype = torch.bool
  392. elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
  393. if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
  394. dtype = torch.get_default_dtype()
  395. elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
  396. if utils.is_complex_dtype(dtype):
  397. dtype = utils.corresponding_real_dtype(dtype)
  398. else:
  399. dtype = dtype
  400. assert shape is not None
  401. return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype) # type: ignore[return-value]
  402. # Number case
  403. # TODO: fix number type promotion (bool, complex->float)
  404. # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
  405. seen_float = False
  406. if isinstance(number, (torch.SymInt, torch.SymFloat)):
  407. for a in args:
  408. assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
  409. seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
  410. if seen_float:
  411. number = sym_float(number)
  412. return TensorMeta(number) # type: ignore[arg-type]
  413. def _complex_only_elementwise_meta(*args, **kwargs):
  414. torch._check(
  415. utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
  416. )
  417. return _prim_elementwise_meta(*args, **kwargs)
  418. def _make_elementwise_unary_prim(
  419. name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
  420. ):
  421. """
  422. Creates an elementwise unary prim.
  423. """
  424. return _make_prim(
  425. schema=f"{name}(Tensor self) -> Tensor",
  426. meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
  427. return_type=RETURN_TYPE.NEW,
  428. **kwargs,
  429. )
  430. def _make_elementwise_binary_prim(
  431. name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
  432. ):
  433. """
  434. Creates an elementwise binary prim.
  435. """
  436. return _make_prim(
  437. schema=f"{name}(Tensor self, Tensor other) -> Tensor",
  438. meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
  439. return_type=RETURN_TYPE.NEW,
  440. **kwargs,
  441. )
  442. def _not_impl(*args, **kwargs):
  443. raise NotImplementedError
#
# Elementwise unary operations
#
# Each assignment below registers ``prims::<name>`` through
# ``_make_elementwise_unary_prim`` and binds the returned op overload to a
# module-level name. ``impl_aten`` is the eager kernel, which runs after the meta
# check. ``type_promotion`` selects how the meta function picks the output dtype.

abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    # The magnitude of a complex value is real, hence COMPLEX_TO_FLOAT.
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# NOTE: the eager kernel here is ``torch.i0`` (not a ``torch.special`` variant).
bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  543. def _cbrt_aten(a: torch.Tensor) -> Tensor:
  544. torch._check(
  545. not a.is_complex(),
  546. lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
  547. )
  548. # Returns the real cubic root of the number.
  549. # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number
  550. # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
  551. # which is a complex number.
  552. # For more info see the section Note in
  553. # https://en.cppreference.com/w/cpp/numeric/math/cbrt
  554. return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
# Uses the sign-preserving ``_cbrt_aten`` as its eager kernel instead of a
# single torch op.
cbrt = _make_elementwise_unary_prim(
    "cbrt",
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  567. def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
  568. if not input.dtype.is_complex:
  569. raise RuntimeError("prims.conj_physical is only defined for complex dtypes")
  570. strides = utils.compute_elementwise_output_strides(input)
  571. return TensorMeta(input, strides=strides)
# Registered with ``_make_prim`` directly because it needs a custom meta
# function (``_conj_physical_meta``) that rejects non-complex dtypes.
conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)
  579. def _clone_meta(
  580. input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
  581. ) -> TensorLikeType:
  582. if memory_format != torch.preserve_format:
  583. return torch.empty(
  584. input.shape,
  585. dtype=input.dtype,
  586. layout=input.layout,
  587. device=input.device,
  588. memory_format=memory_format,
  589. )
  590. # memory_format == torch.preserve_format
  591. strides = utils.compute_elementwise_output_strides(input)
  592. return torch.empty_strided(
  593. input.shape,
  594. strides,
  595. dtype=input.dtype,
  596. layout=input.layout,
  597. device=input.device,
  598. )
# ``memory_format`` is keyword-only and optional in the schema. The meta
# function ``_clone_meta`` computes the output layout.
clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
    return_type=RETURN_TYPE.NEW,
)
# More unary elementwise prims. Several use ``torch.special`` functions as their
# eager kernels (erf_inv, erfc, erfcx, expm1, exp2).
digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfcx = _make_elementwise_unary_prim(
    "erfcx",
    impl_aten=torch.special.erfcx,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  654. def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
  655. return _prim_elementwise_meta(
  656. a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
  657. )
  658. # NOTE: fill uses _make_prim directly because it has a value parameter
  659. fill = _make_prim(
  660. schema="fill(Tensor self, Scalar value) -> Tensor",
  661. return_type=RETURN_TYPE.NEW,
  662. meta=_fill_meta,
  663. impl_aten=torch.fill,
  664. doc="",
  665. )
  666. floor = _make_elementwise_unary_prim(
  667. "floor",
  668. impl_aten=torch.floor,
  669. doc="",
  670. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  671. )
  672. imag = _make_prim(
  673. schema="imag(Tensor(a) self) -> Tensor(a)",
  674. meta=partial(
  675. _complex_only_elementwise_meta,
  676. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
  677. ),
  678. return_type=RETURN_TYPE.VIEW,
  679. impl_aten=torch.imag,
  680. doc="",
  681. )
  682. isfinite = _make_elementwise_unary_prim(
  683. "isfinite",
  684. impl_aten=torch.isfinite,
  685. doc="",
  686. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  687. )
  688. lgamma = _make_elementwise_unary_prim(
  689. "lgamma",
  690. impl_aten=torch.lgamma,
  691. doc="",
  692. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  693. )
  694. log = _make_elementwise_unary_prim(
  695. "log",
  696. impl_aten=torch.log,
  697. doc="",
  698. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  699. )
  700. log1p = _make_elementwise_unary_prim(
  701. "log1p",
  702. impl_aten=torch.log1p,
  703. doc="",
  704. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  705. )
  706. log2 = _make_elementwise_unary_prim(
  707. "log2",
  708. impl_aten=torch.log2,
  709. doc="",
  710. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  711. )
  712. log10 = _make_elementwise_unary_prim(
  713. "log10",
  714. impl_aten=torch.log10,
  715. doc="",
  716. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  717. )
  718. real = _make_prim(
  719. schema="real(Tensor(a) self) -> Tensor(a)",
  720. meta=partial(
  721. _complex_only_elementwise_meta,
  722. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
  723. ),
  724. return_type=RETURN_TYPE.VIEW,
  725. impl_aten=torch.real,
  726. doc="",
  727. )
  728. reciprocal = _make_elementwise_unary_prim(
  729. "reciprocal",
  730. impl_aten=torch.reciprocal,
  731. doc="",
  732. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  733. )
  734. ndtri = _make_elementwise_unary_prim(
  735. "ndtri",
  736. impl_aten=torch.special.ndtri,
  737. doc="",
  738. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  739. )
  740. neg = _make_elementwise_unary_prim(
  741. "neg",
  742. impl_aten=torch.neg,
  743. doc="",
  744. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  745. )
  746. round = _make_elementwise_unary_prim(
  747. "round",
  748. impl_aten=torch.round,
  749. doc="",
  750. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  751. )
  752. rsqrt = _make_elementwise_unary_prim(
  753. "rsqrt",
  754. impl_aten=torch.rsqrt,
  755. doc="",
  756. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  757. )
  758. sign = _make_elementwise_unary_prim(
  759. "sign",
  760. impl_aten=torch.sign,
  761. doc="",
  762. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  763. )
  764. signbit = _make_elementwise_unary_prim(
  765. "signbit",
  766. impl_aten=torch.signbit,
  767. doc="",
  768. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  769. )
  770. sin = _make_elementwise_unary_prim(
  771. "sin",
  772. impl_aten=torch.sin,
  773. doc="",
  774. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  775. )
  776. sinh = _make_elementwise_unary_prim(
  777. "sinh",
  778. impl_aten=torch.sinh,
  779. doc="",
  780. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  781. )
  782. spherical_bessel_j0 = _make_elementwise_unary_prim(
  783. "spherical_bessel_j0",
  784. impl_aten=torch.special.spherical_bessel_j0,
  785. doc="",
  786. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  787. )
  788. sqrt = _make_elementwise_unary_prim(
  789. "sqrt",
  790. impl_aten=torch.sqrt,
  791. doc="",
  792. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  793. )
  794. tan = _make_elementwise_unary_prim(
  795. "tan",
  796. impl_aten=torch.tan,
  797. doc="",
  798. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  799. )
  800. tanh = _make_elementwise_unary_prim(
  801. "tanh",
  802. impl_aten=torch.tanh,
  803. doc="",
  804. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  805. )
  806. trunc = _make_elementwise_unary_prim(
  807. "trunc",
  808. impl_aten=torch.trunc,
  809. doc="",
  810. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  811. )
  812. #
  813. # Elementwise binary operations
  814. #
  815. add = _make_elementwise_binary_prim(
  816. name="add",
  817. impl_aten=torch.add,
  818. doc="",
  819. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  820. )
  821. atan2 = _make_elementwise_binary_prim(
  822. name="atan2",
  823. impl_aten=torch.atan2,
  824. doc="",
  825. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  826. )
  827. bitwise_and = _make_elementwise_binary_prim(
  828. "bitwise_and",
  829. impl_aten=torch.bitwise_and,
  830. doc="",
  831. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  832. )
  833. bitwise_or = _make_elementwise_binary_prim(
  834. "bitwise_or",
  835. impl_aten=torch.bitwise_or,
  836. doc="",
  837. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  838. )
  839. bitwise_xor = _make_elementwise_binary_prim(
  840. "bitwise_xor",
  841. impl_aten=torch.bitwise_xor,
  842. doc="",
  843. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  844. )
  845. # TODO: complex needs a special meta to account for its float -> complex behavior
  846. # complex = _make_elementwise_binary_prim(
  847. # impl_aten=torch.complex,
  848. # doc="",
  849. # )
  850. # div prim performs truncation division on integer inputs
  851. # and true division for floating and complex inputs
  852. def _div_aten(a, b):
  853. is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
  854. isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
  855. )
  856. if is_integral:
  857. return torch.div(a, b, rounding_mode="trunc")
  858. else:
  859. return torch.true_divide(a, b)
  860. div = _make_elementwise_binary_prim(
  861. "div",
  862. impl_aten=_div_aten,
  863. doc="",
  864. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  865. )
  866. eq = _make_elementwise_binary_prim(
  867. "eq",
  868. impl_aten=torch.eq,
  869. doc="",
  870. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  871. )
  872. fmax = _make_elementwise_binary_prim(
  873. "fmax",
  874. impl_aten=torch.fmax,
  875. doc="",
  876. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  877. )
  878. fmin = _make_elementwise_binary_prim(
  879. "fmin",
  880. impl_aten=torch.fmin,
  881. doc="",
  882. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  883. )
  884. fmod = _make_elementwise_binary_prim(
  885. "fmod",
  886. impl_aten=torch.fmod,
  887. doc="",
  888. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  889. )
  890. gcd = _make_elementwise_binary_prim(
  891. "gcd",
  892. impl_aten=torch.gcd,
  893. doc="",
  894. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  895. )
  896. ge = _make_elementwise_binary_prim(
  897. "ge",
  898. impl_aten=torch.ge,
  899. doc="",
  900. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  901. )
  902. gt = _make_elementwise_binary_prim(
  903. "gt",
  904. impl_aten=torch.gt,
  905. doc="",
  906. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  907. )
  908. hypot = _make_elementwise_binary_prim(
  909. "hypot",
  910. impl_aten=torch.hypot,
  911. doc="",
  912. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  913. )
  914. igamma = _make_elementwise_binary_prim(
  915. "igamma",
  916. impl_aten=torch.special.gammainc,
  917. doc="",
  918. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  919. )
  920. igammac = _make_elementwise_binary_prim(
  921. "igammac",
  922. impl_aten=torch.special.gammaincc,
  923. doc="",
  924. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  925. )
  926. le = _make_elementwise_binary_prim(
  927. "le",
  928. impl_aten=torch.le,
  929. doc="",
  930. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  931. )
  932. lt = _make_elementwise_binary_prim(
  933. "lt",
  934. impl_aten=torch.lt,
  935. doc="",
  936. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  937. )
  938. # Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs
  939. def _maximum_aten(
  940. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  941. ) -> TensorLikeType:
  942. if isinstance(a, TensorLike) and isinstance(b, Number):
  943. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  944. elif isinstance(b, TensorLike) and isinstance(a, Number):
  945. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  946. return torch.maximum(a, b) # type: ignore[arg-type]
  947. maximum = _make_elementwise_binary_prim(
  948. "maximum",
  949. impl_aten=_maximum_aten,
  950. doc="",
  951. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  952. )
  953. def _minimum_aten(
  954. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  955. ) -> TensorLikeType:
  956. if isinstance(a, TensorLike) and isinstance(b, Number):
  957. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  958. elif isinstance(b, TensorLike) and isinstance(a, Number):
  959. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  960. return torch.minimum(a, b) # type: ignore[arg-type]
  961. minimum = _make_elementwise_binary_prim(
  962. "minimum",
  963. impl_aten=_minimum_aten,
  964. doc="",
  965. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  966. )
  967. mul = _make_elementwise_binary_prim(
  968. "mul",
  969. impl_aten=torch.mul,
  970. doc="",
  971. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  972. )
  973. ne = _make_elementwise_binary_prim(
  974. "ne",
  975. impl_aten=torch.ne,
  976. doc="",
  977. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  978. )
  979. nextafter = _make_elementwise_binary_prim(
  980. "nextafter",
  981. impl_aten=torch.nextafter,
  982. doc="",
  983. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  984. )
  985. pow = _make_elementwise_binary_prim(
  986. "pow",
  987. impl_aten=torch.pow,
  988. doc="",
  989. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  990. )
  991. remainder = _make_elementwise_binary_prim(
  992. "remainder",
  993. impl_aten=torch.remainder,
  994. doc="",
  995. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  996. )
  997. shift_left = _make_elementwise_binary_prim(
  998. "shift_left",
  999. impl_aten=torch.bitwise_left_shift,
  1000. doc="",
  1001. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1002. )
  1003. shift_right_arithmetic = _make_elementwise_binary_prim(
  1004. "shift_right_arithmetic",
  1005. impl_aten=torch.bitwise_right_shift,
  1006. doc="",
  1007. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1008. )
  1009. shift_right_logical = _not_impl
  1010. sub = _make_elementwise_binary_prim(
  1011. "sub",
  1012. impl_aten=torch.sub,
  1013. doc="",
  1014. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1015. )
  1016. zeta = _make_elementwise_binary_prim(
  1017. "zeta",
  1018. impl_aten=torch.special.zeta,
  1019. doc="",
  1020. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1021. )
  1022. #
  1023. # View operations
  1024. def _as_strided_meta(
  1025. a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
  1026. ) -> TensorLikeType:
  1027. assert len(size) == len(stride)
  1028. assert storage_offset >= 0
  1029. utils.validate_strides(stride)
  1030. utils.validate_shape(size)
  1031. if reduce(operator.mul, size) == 0:
  1032. # NOTE: This special case is to avoid having to acquire the storage below
  1033. # as_strided to shapes with no elements are trivially valid, so it's OK
  1034. pass
  1035. elif isinstance(a, torch.Tensor):
  1036. utils.check_in_bounds_for_storage(
  1037. a._typed_storage(), size, stride, storage_offset
  1038. )
  1039. return torch.as_strided(a, size, stride, storage_offset)
  1040. def _as_strided_aten(
  1041. a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
  1042. ) -> Tensor:
  1043. return torch.as_strided(a, size, stride, storage_offset)
  1044. _as_strided_doc = """
  1045. Creates a view of the tensor with the given shape (size), strides (stride) and
  1046. storage offset (storage_offset).
  1047. """
  1048. as_strided = _make_prim(
  1049. schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
  1050. meta=_as_strided_meta,
  1051. impl_aten=_as_strided_aten,
  1052. return_type=RETURN_TYPE.VIEW,
  1053. doc=_as_strided_doc,
  1054. )
  1055. def _broadcast_in_dim_meta(
  1056. a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
  1057. ):
  1058. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  1059. # Type checks
  1060. assert isinstance(a, TensorLike)
  1061. assert isinstance(shape, Sequence)
  1062. assert isinstance(broadcast_dimensions, Sequence)
  1063. # every dimension must be accounted for
  1064. assert a.ndim == len(broadcast_dimensions)
  1065. # broadcast shape must have weakly more dimensions
  1066. assert len(shape) >= a.ndim
  1067. # broadcast_dimensions must be an ascending sequence
  1068. # (no relative reordering of dims) of integers and
  1069. # each dimension must be within the new shape
  1070. def _greater_than_reduce(acc, x):
  1071. assert isinstance(x, Dim)
  1072. assert x > acc
  1073. assert x < len(shape)
  1074. return x
  1075. reduce(_greater_than_reduce, broadcast_dimensions, -1)
  1076. # shape must be broadcastable to
  1077. for idx, new_idx in enumerate(broadcast_dimensions):
  1078. if not guard_size_oblivious(a.shape[idx] == 1):
  1079. torch._check(
  1080. a.shape[idx] == shape[new_idx],
  1081. lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
  1082. )
  1083. new_strides = []
  1084. original_idx = 0
  1085. for idx in range(len(shape)):
  1086. if idx in broadcast_dimensions:
  1087. # Assigns a stride of zero to dimensions
  1088. # which were actually broadcast
  1089. if guard_size_oblivious(a.shape[original_idx] != shape[idx]):
  1090. new_strides.append(0)
  1091. else:
  1092. new_strides.append(a.stride()[original_idx])
  1093. original_idx = original_idx + 1
  1094. else:
  1095. if guard_size_oblivious(shape[idx] != 1):
  1096. new_strides.append(0)
  1097. elif original_idx == a.ndim:
  1098. new_strides.append(1)
  1099. else:
  1100. new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
  1101. return a.as_strided(shape, new_strides, a.storage_offset())
  1102. def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
  1103. s = list(shape)
  1104. for broadcast_dimension in broadcast_dimensions:
  1105. s[broadcast_dimension] = -1
  1106. v = a
  1107. for idx, x in enumerate(s):
  1108. if x != -1:
  1109. v = v.unsqueeze(idx)
  1110. return v.expand(shape)
  1111. _broadcast_in_dim_doc = """
  1112. Creates a view of a with the specified shape.
  1113. Allows adding dimensions of any length and broadcasting
  1114. dimensions of length one in a to any length.
  1115. The location of the broadcast dimensions must be specified
  1116. using the broadcast_dimensions argument. Changing the
  1117. relative order of dimensions is not supported.
  1118. """
  1119. broadcast_in_dim = _make_prim(
  1120. schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
  1121. meta=_broadcast_in_dim_meta,
  1122. impl_aten=_broadcast_in_dim_aten,
  1123. return_type=RETURN_TYPE.VIEW,
  1124. doc=_broadcast_in_dim_doc,
  1125. )
  1126. def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
  1127. # Special-case for zero dimensional tensors
  1128. ndim = max(1, a.dim())
  1129. utils.validate_idx(ndim, start)
  1130. utils.validate_idx(ndim, end)
  1131. # Verifies end is strictly greater than start
  1132. # (Collapse requires a non-empty interval)
  1133. torch._check_value(
  1134. end >= start,
  1135. lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
  1136. )
  1137. def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
  1138. """
  1139. Returns the shape of a with dims in [start, end) merged into a single dimension.
  1140. """
  1141. # Special-case for zero dimensional tensors
  1142. shape = (1,) if len(shape) == 0 else tuple(shape)
  1143. dim_length = 1
  1144. for s in shape[start : end + 1]:
  1145. dim_length = dim_length * s
  1146. return shape[0:start] + (dim_length,) + shape[end + 1 :]
  1147. def _collapse_view_helper(
  1148. a: TensorLikeType, start: int, end: int
  1149. ) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
  1150. assert isinstance(a, TensorLike)
  1151. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  1152. _validate_collapse_args(a, start, end)
  1153. # Special-case for zero dimensional tensors
  1154. if a.ndim == 0:
  1155. shape = (1,)
  1156. strides = (1,)
  1157. else:
  1158. shape = a.shape # type: ignore[assignment]
  1159. strides = a.stride() # type: ignore[assignment]
  1160. if a.ndim == 0 or (end == start):
  1161. return shape, strides
  1162. length = shape[end]
  1163. stride = strides[end]
  1164. for idx in range(end - 1, start - 1, -1):
  1165. if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
  1166. shape[idx + 1] == 0
  1167. ):
  1168. length = 0
  1169. stride = 0
  1170. break
  1171. if guard_size_oblivious(shape[idx] == 1):
  1172. continue
  1173. length = length * shape[idx]
  1174. if guard_size_oblivious(stride < strides[idx]):
  1175. stride = stride
  1176. else:
  1177. stride = strides[idx]
  1178. if (
  1179. guard_size_oblivious(a.numel() > 0)
  1180. and guard_size_oblivious(shape[idx + 1] != 1)
  1181. and not guard_size_oblivious(
  1182. strides[idx] == strides[idx + 1] * shape[idx + 1]
  1183. )
  1184. ):
  1185. return None, None
  1186. new_shape = shape[:start] + (length,) + shape[end + 1 :]
  1187. new_strides = strides[:start] + (stride,) + strides[end + 1 :]
  1188. # NOTE: when the input has no elements it's restrided as if it were contiguous
  1189. if guard_size_oblivious(a.numel() == 0):
  1190. new_strides = utils.make_contiguous_strides_for(new_shape)
  1191. return new_shape, new_strides
  1192. def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
  1193. new_shape, new_strides = _collapse_view_helper(a, start, end)
  1194. if new_shape is None:
  1195. msg = "Attempting to view a collapsed tensor, but no such view exists!"
  1196. raise ValueError(msg)
  1197. assert new_strides is not None
  1198. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1199. def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
  1200. new_shape = _collapsed_shape(a.shape, start, end)
  1201. return a.view(new_shape)
  1202. _collapse_view_doc = """
  1203. Creates a view of a with the dimensions between
  1204. start (inclusive) and end (exclusive) merged into a
  1205. single dimension.
  1206. If it's not possible to take such a view then an error
  1207. is thrown. See collapse instead.
  1208. The dimensions can be merged if and only if
  1209. they are all "nested" with each other. That is, they all
  1210. have the property that
  1211. stride[i] = stride[i+1] * shape[i+1]
  1212. for all i in [start, end - 1).
  1213. """
  1214. collapse_view = _make_prim(
  1215. schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
  1216. meta=_collapse_view_meta,
  1217. impl_aten=_collapse_view_aten,
  1218. return_type=RETURN_TYPE.VIEW,
  1219. doc=_collapse_view_doc,
  1220. )
  1221. def _conj_meta(a: TensorLikeType) -> TensorLikeType:
  1222. if not a.dtype.is_complex:
  1223. raise RuntimeError("Expected complex dtype in prims.conj")
  1224. out = a.as_strided(a.shape, a.stride(), a.storage_offset())
  1225. torch._C._set_conj(out, not a.is_conj())
  1226. return out
  1227. _conj_doc = """
  1228. Returns a conjugated view of the original tensor
  1229. """
  1230. conj = _make_prim(
  1231. schema="conj(Tensor(a) a) -> Tensor(a)",
  1232. meta=_conj_meta,
  1233. impl_aten=torch.conj,
  1234. return_type=RETURN_TYPE.VIEW,
  1235. doc=_conj_doc,
  1236. )
  1237. def expand_dims(
  1238. a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
  1239. ) -> TensorLikeType:
  1240. """
  1241. Creates a view of a with a.ndim + len(dimensions) dimensions, with new
  1242. dimensions of length one at the dimensions specified by dimensions.
  1243. """
  1244. if ndim is not None:
  1245. # TODO: this is only here to support the unsqueeze ref
  1246. dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type]
  1247. else:
  1248. dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
  1249. if len(set(dims)) != len(dims):
  1250. msg = f"Received duplicate dimensions to expand in {str(dimensions)}"
  1251. raise ValueError(msg)
  1252. new_shape = list(a.shape)
  1253. for idx in dims:
  1254. new_shape.insert(idx, 1)
  1255. broadcast_dimensions = [
  1256. idx for idx in range(len(new_shape)) if idx not in dimensions
  1257. ]
  1258. return broadcast_in_dim(a, new_shape, broadcast_dimensions)
  1259. # Note: saves the Python slice object because we're about to clobber its name with the slice prim
  1260. pyslice: Type[slice] = slice # type: ignore[has-type]
  1261. def _slice_meta(
  1262. a: TensorLikeType,
  1263. start_indices: DimsSequenceType,
  1264. limit_indices: DimsSequenceType,
  1265. strides: Optional[StrideType] = None,
  1266. ) -> TensorLikeType:
  1267. _strides = strides if strides is not None else [1] * len(start_indices)
  1268. if a.ndim != len(start_indices):
  1269. msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!"
  1270. raise ValueError(msg)
  1271. if a.ndim != len(limit_indices):
  1272. msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!"
  1273. raise ValueError(msg)
  1274. if a.ndim != len(_strides):
  1275. msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(limit_indices)}!"
  1276. raise ValueError(msg)
  1277. for x, y in zip(start_indices, a.shape):
  1278. if x < 0:
  1279. msg = f"Attempting to slice a tensor with a negative start index of {x}!"
  1280. raise ValueError(msg)
  1281. if x > y:
  1282. msg = (
  1283. f"Attempting to slice a tensor but a start index in {start_indices} is greater than"
  1284. f" the length of its corresponding dimension in shape {a.shape}"
  1285. )
  1286. raise ValueError(msg)
  1287. for x, y, z in zip(limit_indices, a.shape, start_indices):
  1288. if x < 0:
  1289. msg = f"Attempting to slice a tensor with a negative stop index of {x}!"
  1290. raise ValueError(msg)
  1291. if x > y:
  1292. msg = (
  1293. f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of "
  1294. f" its corresponding dimension in shape {a.shape}"
  1295. )
  1296. raise ValueError(msg)
  1297. if x < z:
  1298. msg = (
  1299. f"Attempting to slice a tensor but a start index in {x} is greater than "
  1300. f" its corresponding stop index {z}"
  1301. )
  1302. for x in _strides:
  1303. if x <= 0:
  1304. msg = f"Attempting to slice a tensor with a non-positive step of {x}!"
  1305. raise ValueError(msg)
  1306. new_shape = []
  1307. for x, y, z in zip(start_indices, limit_indices, _strides):
  1308. new_shape.append(1 + (y - x - 1) // z)
  1309. new_strides = []
  1310. for x, y in zip(a.stride(), _strides):
  1311. new_strides.append(x * y)
  1312. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1313. def _slice_aten(
  1314. a: Tensor,
  1315. start_indices: DimsSequenceType,
  1316. limit_indices: DimsSequenceType,
  1317. strides: Optional[StrideType] = None,
  1318. ) -> Tensor:
  1319. _strides = strides if strides is not None else [1] * len(start_indices)
  1320. slices = []
  1321. for start, stop, step in zip(start_indices, limit_indices, _strides):
  1322. slices.append(pyslice(start, stop, step))
  1323. return operator.getitem(a, slices) # type: ignore[call-overload]
  1324. _slice_doc = """
  1325. Creates a view of a "bounding box" within the tensor.
  1326. The bounding box is specified independently in each of the tensor's dimensions.
  1327. start_indices and limit_indices describe the box's boundaries for their corresponding
  1328. dimensions. If strides is specified then they specify the step size between elements
  1329. in their corresponding dimension.
  1330. This operation is analogous to slicing in NumPy, but does not permit slices where
  1331. the stop indices are less than the start indices.
  1332. """
  1333. slice = _make_prim(
  1334. schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
  1335. meta=_slice_meta,
  1336. impl_aten=_slice_aten,
  1337. return_type=RETURN_TYPE.VIEW,
  1338. doc=_slice_doc,
  1339. )
  1340. def _slice_in_dim_meta(
  1341. a: TensorLikeType,
  1342. start_index: int,
  1343. limit_index: int,
  1344. stride: int = 1,
  1345. axis: int = 0,
  1346. ) -> TensorLikeType:
  1347. if axis < 0:
  1348. msg = f"slice_in_dim: received a negative axis {axis}"
  1349. raise ValueError(msg)
  1350. if axis >= a.ndim:
  1351. msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor"
  1352. raise ValueError(msg)
  1353. if start_index < 0:
  1354. msg = f"slice_in_dim: received a negative start_index {start_index}"
  1355. raise ValueError(msg)
  1356. if start_index > a.shape[axis]:
  1357. msg = f"slice_in_dim: start_index is greater than the length {start_index} of dimension {axis}"
  1358. raise ValueError(msg)
  1359. if limit_index > a.shape[axis]:
  1360. msg = f"slice_in_dim: limit_index is greater than the length {limit_index} of dimension {axis}"
  1361. raise ValueError(msg)
  1362. if limit_index < start_index:
  1363. msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}"
  1364. raise ValueError(msg)
  1365. if stride < 0:
  1366. msg = f"slice_in_dim: received a non-positive stride of {stride}!"
  1367. raise ValueError(msg)
  1368. start_indices = [0] * a.ndim
  1369. limit_indices = list(a.shape)
  1370. strides = [1] * a.ndim
  1371. start_indices[axis] = start_index
  1372. limit_indices[axis] = limit_index
  1373. strides[axis] = stride
  1374. return _slice_meta(a, start_indices, limit_indices, strides)
  1375. def _slice_in_dim_aten(
  1376. a: Tensor,
  1377. start_index: int,
  1378. limit_index: int,
  1379. stride: int = 1,
  1380. axis: int = 0,
  1381. ) -> Tensor:
  1382. start_indices = [0] * a.ndim
  1383. limit_indices = list(a.shape)
  1384. strides = [1] * a.ndim
  1385. start_indices[axis] = start_index
  1386. limit_indices[axis] = limit_index
  1387. strides[axis] = stride
  1388. return slice(a, start_indices, limit_indices, strides)
  1389. _slice_in_dim_doc = """
  1390. Convenience wrapper for slicing just one dimension using slice.
  1391. """
  1392. # TODO: make stride SymInt
  1393. slice_in_dim = _make_prim(
  1394. schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
  1395. meta=_slice_in_dim_meta,
  1396. impl_aten=_slice_in_dim_aten,
  1397. return_type=RETURN_TYPE.VIEW,
  1398. doc=_slice_in_dim_doc,
  1399. )
  1400. def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
  1401. assert isinstance(a, TensorLike)
  1402. utils.validate_idx(a.ndim, dim)
  1403. utils.validate_dim_length(outer_length)
  1404. # Verifies the dim can be split with the specified lhs_length
  1405. inner_length = a.shape[dim] // outer_length
  1406. if (a.shape[dim] % outer_length) != 0:
  1407. msg = (
  1408. f"Attempting to split dimension of length {a.shape[dim]}, "
  1409. f"but outer length of {outer_length} divides it with a remainder!"
  1410. )
  1411. raise ValueError(msg)
  1412. new_shape: List[int] = []
  1413. new_strides: List[int] = []
  1414. for idx in range(a.ndim):
  1415. if idx == dim:
  1416. new_shape.extend((outer_length, inner_length))
  1417. new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
  1418. else:
  1419. new_shape.append(a.shape[idx])
  1420. new_strides.append(a.stride()[idx])
  1421. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1422. def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
  1423. inner_length = a.shape[dim] // outer_length
  1424. new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
  1425. return a.view(new_shape)
  1426. _split_dim_doc = """
  1427. Creates a view of a with the given dimension (of length l) split
  1428. into two dimensions, with the outer of the two having
  1429. length outer_length and the inner of the two having computed
  1430. length inner_length such outer_length * inner_length = l.
  1431. """
  1432. # TODO: consider renaming split_dim_view
  1433. split_dim = _make_prim(
  1434. schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
  1435. meta=_split_dim_meta,
  1436. impl_aten=_split_dim_aten,
  1437. return_type=RETURN_TYPE.VIEW,
  1438. doc=_split_dim_doc,
  1439. )
  1440. # Note: allows dimensions to be specified redundantly
  1441. def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
  1442. assert isinstance(a, TensorLike)
  1443. for idx in dimensions:
  1444. utils.validate_idx(a.ndim, idx)
  1445. assert a.shape[idx] == 1
  1446. new_shape = []
  1447. new_strides = []
  1448. for idx in range(len(a.shape)):
  1449. if idx in dimensions:
  1450. continue
  1451. new_shape.append(a.shape[idx])
  1452. new_strides.append(a.stride()[idx])
  1453. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1454. _squeeze_doc = """
  1455. Creates a view of the tensor with the specified dimensions removed.
  1456. The removed dimensions must each have length one.
  1457. """
  1458. squeeze = _make_prim(
  1459. schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
  1460. meta=_squeeze_meta,
  1461. impl_aten=torch.squeeze,
  1462. return_type=RETURN_TYPE.VIEW,
  1463. doc=_squeeze_doc,
  1464. )
  1465. def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
  1466. if a.ndim != len(permutation):
  1467. msg = f"Attempting to permute a tensor of rank {a.ndim}, but received a permutation of length {len(permutation)}!"
  1468. raise ValueError(msg)
  1469. if not utils.is_valid_permutation(a.ndim, permutation):
  1470. msg = f"Received an invalid permutation, {permutation}!"
  1471. raise ValueError(msg)
  1472. new_shape = [0] * a.ndim
  1473. new_strides = [0] * a.ndim
  1474. for idx, dim in enumerate(permutation):
  1475. new_shape[idx] = a.shape[dim]
  1476. new_strides[idx] = a.stride()[dim]
  1477. return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
  1478. def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
  1479. return torch.permute(a, permutation)
  1480. _transpose_doc = """
  1481. Creates a view of the tensor with its dimensions permuted.
  1482. The length of the permutation must be the rank of the tensor,
  1483. and each element of the permutation specifies the new order
  1484. for the corresponding dimension.
  1485. """
  1486. transpose = _make_prim(
  1487. schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
  1488. meta=_transpose_meta,
  1489. impl_aten=_transpose_aten,
  1490. return_type=RETURN_TYPE.VIEW,
  1491. doc=_transpose_doc,
  1492. )
  1493. def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
  1494. return a.as_strided(a.shape, a.stride(), a.storage_offset())
  1495. def _view_of_aten(a: Tensor) -> Tensor:
  1496. return a.view(a.shape)
  1497. _view_of_doc = """
  1498. Creates a view of the tensor.
  1499. """
  1500. view_of = _make_prim(
  1501. schema="view_of(Tensor(a) a) -> Tensor(a)",
  1502. meta=_view_of_meta,
  1503. impl_aten=_view_of_aten,
  1504. return_type=RETURN_TYPE.VIEW,
  1505. doc=_view_of_doc,
  1506. )
  1507. def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
  1508. return a.view(dtype)
  1509. def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
  1510. return a.view(dtype)
  1511. _view_element_type_doc = """
  1512. Creates a view of the tensor with a different dtype.
  1513. """
  1514. view_element_type = _make_prim(
  1515. schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor(a)",
  1516. meta=_view_element_type_meta,
  1517. impl_aten=_view_element_type_aten,
  1518. return_type=RETURN_TYPE.VIEW,
  1519. doc=_view_element_type_doc,
  1520. )
  1521. #
  1522. # Functionalized view mutations
  1523. #
  1524. def _as_strided_scatter_meta(
  1525. input: TensorLikeType,
  1526. src: TensorLikeType,
  1527. size: ShapeType,
  1528. stride: StrideType,
  1529. storage_offset: int,
  1530. ) -> TensorLikeType:
  1531. utils.validate_shape(size)
  1532. utils.validate_strides(stride)
  1533. required_size = utils.compute_required_storage_length(size, stride, storage_offset)
  1534. torch._check(
  1535. input.numel() >= required_size,
  1536. lambda: (
  1537. f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
  1538. f" and itemsize {input.element_size()} requiring a storage size of "
  1539. f"{required_size * input.element_size()} are out of bounds "
  1540. f"for storage of size {input.numel() * input.element_size()}"
  1541. ),
  1542. )
  1543. torch._check(
  1544. utils.is_same_shape(src.shape, size),
  1545. lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
  1546. )
  1547. return utils.clone_preserve_strides(input)
  1548. _as_strided_scatter_doc = """
  1549. Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
  1550. ``out.as_strided(size, stride, storage_offset).copy_(src)``.
  1551. """
  1552. as_strided_scatter = _make_prim(
  1553. schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
  1554. meta=_as_strided_scatter_meta,
  1555. impl_aten=torch.as_strided_scatter,
  1556. return_type=RETURN_TYPE.NEW,
  1557. doc=_as_strided_scatter_doc,
  1558. )
  1559. #
  1560. # Shape operations
  1561. #
  1562. def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
  1563. # Special-case for zero dimensional tensors
  1564. _validate_collapse_args(a, start, end)
  1565. new_shape = _collapsed_shape(a.shape, start, end)
  1566. return a.new_empty(new_shape)
  1567. def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
  1568. new_shape = _collapsed_shape(a.shape, start, end)
  1569. out = a.new_empty(new_shape)
  1570. with torch.no_grad():
  1571. out.view_as(a).copy_(a)
  1572. return out
  1573. _collapse_doc = """
  1574. Collapse a span of neighboring dimensions into one.
  1575. See collapse_view for the corresponding view operation.
  1576. """
  1577. collapse = _make_prim(
  1578. schema="collapse(Tensor a, int start, int end) -> Tensor",
  1579. meta=_collapse_meta,
  1580. impl_aten=_collapse_aten,
  1581. return_type=RETURN_TYPE.NEW,
  1582. doc=_collapse_doc,
  1583. )
  1584. # TODO: review stride logic
  1585. # NB: unlike torch.cat, this is more strict about empty tensors and dim is
  1586. # never negative
  1587. def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
  1588. # Verifies same shape (except in the concat dimension)
  1589. assert dim >= 0
  1590. shape = tensors[0].shape
  1591. concat_length = 0
  1592. for tensor_idx, tensor in enumerate(tensors):
  1593. assert len(shape) == len(tensor.shape)
  1594. for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
  1595. if idx == dim:
  1596. concat_length = concat_length + length
  1597. else:
  1598. torch._check(
  1599. length == common_length,
  1600. lambda: f"Sizes of tensors must match except in dimension {dim}. "
  1601. f"Expected {common_length} but got {length} for tensor number "
  1602. f"{tensor_idx} in the list",
  1603. )
  1604. new_shape = list(tensors[0].shape).copy()
  1605. new_shape[dim] = concat_length
  1606. return TensorMeta(
  1607. tensors[0],
  1608. shape=new_shape,
  1609. strides=utils.make_contiguous_strides_for(new_shape),
  1610. )
  1611. def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
  1612. return torch.cat(tensors, dim)
  1613. _cat_doc = """
  1614. Concatenates tensors along the specified dimension.
  1615. The tensors' shapes must have the same rank and same length for other dimensions.
  1616. """
  1617. cat = _make_prim(
  1618. schema="cat(Tensor[] tensors, int dim) -> Tensor",
  1619. meta=_cat_meta,
  1620. impl_aten=_cat_aten,
  1621. return_type=RETURN_TYPE.NEW,
  1622. doc=_cat_doc,
  1623. )
  1624. def _reshape_meta(a: TensorLikeType, shape: ShapeType):
  1625. assert isinstance(a, TensorLike)
  1626. utils.validate_shape(shape)
  1627. # Validates the tensor and the requested shape have the
  1628. # same number of elements
  1629. numel = reduce(operator.mul, shape)
  1630. if numel != a.numel():
  1631. msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!"
  1632. raise ValueError(msg)
  1633. return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
  1634. def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
  1635. return a.reshape(shape).contiguous().clone()
  1636. _reshape_doc = """
  1637. Creates a contiguous tensor with the specified shape
  1638. containing a copy of the data in a.
  1639. """
  1640. reshape = _make_prim(
  1641. schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
  1642. meta=_reshape_meta,
  1643. impl_aten=_reshape_aten,
  1644. return_type=RETURN_TYPE.NEW,
  1645. doc=_reshape_doc,
  1646. )
  1647. def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
  1648. utils.validate_dimension_indices(a.ndim, dims)
  1649. return torch.empty_like(a, memory_format=torch.preserve_format)
  1650. _rev_doc = """
  1651. Reverses the order of elements along the given dimensions.
  1652. """
  1653. rev = _make_prim(
  1654. schema="rev(Tensor a, int[] dims) -> Tensor",
  1655. meta=_rev_meta,
  1656. impl_aten=torch.flip,
  1657. return_type=RETURN_TYPE.NEW,
  1658. doc=_rev_doc,
  1659. )
  1660. #
  1661. # Conditional prims
  1662. #
  1663. def _where_meta(
  1664. pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
  1665. ) -> TensorLikeType:
  1666. return _prim_elementwise_meta(
  1667. a,
  1668. b,
  1669. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1670. args_with_fixed_dtypes=(pred,),
  1671. )
  1672. _where_doc = """
  1673. Selects elements from a and b according to pred.
  1674. Where pred is true the result contains the element from a, and
  1675. where pred is false the result contains the element from b.
  1676. """
  1677. where = _make_prim(
  1678. schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
  1679. meta=_where_meta,
  1680. impl_aten=torch.where,
  1681. return_type=RETURN_TYPE.NEW,
  1682. doc=_where_doc,
  1683. )
  1684. #
  1685. # Type conversions
  1686. #
  1687. def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
  1688. # Type checks
  1689. assert isinstance(a, TensorLike)
  1690. assert isinstance(dtype, torch.dtype)
  1691. # dtype conversion preserves dense strides
  1692. if torch._prims_common.is_non_overlapping_and_dense(a):
  1693. strides = a.stride()
  1694. else:
  1695. strides = utils.compute_elementwise_output_strides(a)
  1696. return TensorMeta(a, strides=strides, dtype=dtype)
  1697. def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
  1698. # Propagates requires grad when possible
  1699. if not utils.is_grad_dtype(dtype):
  1700. requires_grad = False
  1701. else:
  1702. # TODO: update meta objects so this can be acquired directly
  1703. try:
  1704. requires_grad = a.requires_grad
  1705. except Exception as e:
  1706. requires_grad = False
  1707. result = torch.empty_like(
  1708. a, device=a.device, dtype=dtype, requires_grad=requires_grad
  1709. )
  1710. with torch.no_grad():
  1711. return copy_to(result, a)
  1712. _convert_element_type_doc = """
  1713. Creates a copy of a tensor with the given dtype.
  1714. """
  1715. convert_element_type = _make_prim(
  1716. schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
  1717. meta=_convert_element_type_meta,
  1718. impl_aten=_convert_element_type_aten,
  1719. return_type=RETURN_TYPE.NEW,
  1720. doc=_convert_element_type_doc,
  1721. tags=(torch.Tag.pointwise,),
  1722. )
  1723. def _device_put_meta(
  1724. a: TensorLikeType, device: Union[str, torch.device]
  1725. ) -> TensorLikeType:
  1726. assert isinstance(a, TensorLike)
  1727. assert isinstance(device, (str, torch.device))
  1728. return TensorMeta(a, device=utils.canonicalize_device(device))
  1729. def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
  1730. return a.to(device)
  1731. _device_put_doc = """
  1732. Creates a copy of a tensor on the given device.
  1733. """
  1734. device_put = _make_prim(
  1735. schema="device_put(Tensor a, Device device) -> Tensor",
  1736. meta=_device_put_meta,
  1737. impl_aten=_device_put_aten,
  1738. return_type=RETURN_TYPE.NEW,
  1739. doc=_device_put_doc,
  1740. )
  1741. # NOTE: need to model meta scalars
  1742. # See https://github.com/pytorch/pytorch/issues/78070
  1743. def _item_meta(a: TensorLikeType) -> FakeTensor:
  1744. number_type = utils.dtype_to_type(a.dtype)
  1745. return TensorMeta(number_type(-1))
  1746. _item_doc = """
  1747. Converts a tensor with one element to a Python number.
  1748. """
  1749. # TODO: create a new return type for scalars?
  1750. # FIXME: currently returns integers for boolean tensors
  1751. # https://github.com/pytorch/pytorch/issues/78071
  1752. item = _make_prim(
  1753. schema="item(Tensor a) -> Scalar",
  1754. meta=_item_meta,
  1755. impl_aten=torch.Tensor.item,
  1756. return_type=RETURN_TYPE.NEW,
  1757. doc=_item_doc,
  1758. )
  1759. # NOTE: need to model meta scalars
  1760. # See https://github.com/pytorch/pytorch/issues/78070
  1761. def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
  1762. number_type = utils.dtype_to_type(dtype)
  1763. return TensorMeta(number_type(-1))
  1764. def _maximum_value_aten(dtype: torch.dtype):
  1765. if dtype == torch.bool:
  1766. return True
  1767. elif dtype.is_complex or dtype.is_floating_point:
  1768. return torch.finfo(dtype).max
  1769. else:
  1770. return torch.iinfo(dtype).max
  1771. _maximum_value_doc = """
  1772. Return the maximum finite value for a dtype.
  1773. """
  1774. # TODO: create a new return type for scalars?
  1775. # FIXME: currently returns integers for boolean tensors
  1776. # https://github.com/pytorch/pytorch/issues/78071
  1777. maximum_value = _make_prim(
  1778. schema="maximum_value(ScalarType dtype) -> Scalar",
  1779. meta=_maximum_value_meta,
  1780. impl_aten=_maximum_value_aten,
  1781. return_type=RETURN_TYPE.NEW,
  1782. doc=_maximum_value_doc,
  1783. )
  1784. # NOTE: need to model meta scalars
  1785. # See https://github.com/pytorch/pytorch/issues/78070
  1786. def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
  1787. number_type = utils.dtype_to_type(dtype)
  1788. return TensorMeta(number_type(-1))
  1789. def _minimum_value_aten(dtype: torch.dtype):
  1790. if dtype == torch.bool:
  1791. return False
  1792. elif dtype.is_complex or dtype.is_floating_point:
  1793. return torch.finfo(dtype).min
  1794. else:
  1795. return torch.iinfo(dtype).min
  1796. _minimum_value_doc = """
  1797. Return the minimum finite value for a dtype.
  1798. """
  1799. # TODO: create a new return type for scalars?
  1800. # FIXME: currently returns integers for boolean tensors
  1801. # https://github.com/pytorch/pytorch/issues/78071
  1802. minimum_value = _make_prim(
  1803. schema="minimum_value(ScalarType dtype) -> Scalar",
  1804. meta=_minimum_value_meta,
  1805. impl_aten=_minimum_value_aten,
  1806. return_type=RETURN_TYPE.NEW,
  1807. doc=_minimum_value_doc,
  1808. )
  1809. #
  1810. # Inplace operators
  1811. #
  1812. def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
  1813. assert isinstance(a, TensorLike)
  1814. assert isinstance(b, TensorLike)
  1815. # Validates the cast is safe
  1816. # TODO: move this as an option on the reference
  1817. # a_typ = utils.dtype_to_type(a.dtype)
  1818. # b_typ = utils.dtype_to_type(b.dtype)
  1819. # if a_typ is not utils.get_higher_type(a_typ, b_typ):
  1820. # raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")
  1821. # Validates the tensors have the same number of elements
  1822. if a.numel() != b.numel():
  1823. msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!"
  1824. raise RuntimeError(msg)
  1825. return a
  1826. def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
  1827. return a.copy_(b)
  1828. _copy_to_doc = """
  1829. Copies the data in b to a and returns the modified a.
  1830. """
  1831. # TODO: Remove safe casting and implement on reference instead
  1832. copy_to = _make_prim(
  1833. schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
  1834. meta=_copy_to_meta,
  1835. impl_aten=_copy_to_aten,
  1836. return_type=RETURN_TYPE.INPLACE,
  1837. doc=_copy_to_doc,
  1838. )
  1839. def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
  1840. assert isinstance(a, TensorLike)
  1841. return torch.empty_strided(
  1842. a.shape,
  1843. stride,
  1844. dtype=a.dtype,
  1845. layout=a.layout,
  1846. device=a.device,
  1847. requires_grad=a.requires_grad,
  1848. )
  1849. def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
  1850. out = torch.empty_strided(
  1851. a.size(),
  1852. stride=stride,
  1853. dtype=a.dtype,
  1854. layout=a.layout,
  1855. device=a.device,
  1856. requires_grad=a.requires_grad,
  1857. )
  1858. out.copy_(a)
  1859. return out
  1860. _copy_strided_doc = """
  1861. Copies the data in a to a new tensor, the new tensor has same shape with a size, but has different stride.
  1862. """
  1863. copy_strided = _make_prim(
  1864. schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
  1865. meta=_copy_strided_meta,
  1866. impl_aten=_copy_strided_aten,
  1867. return_type=RETURN_TYPE.NEW,
  1868. doc=_copy_strided_doc,
  1869. )
  1870. def _resize_meta(a: TensorLikeType, shape: ShapeType):
  1871. return a.resize_(shape)
  1872. def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
  1873. return a.resize_(shape)
  1874. _resize_doc = """
  1875. Gives a tensor with no elements a new shape, returning the modified tensor.
  1876. The tensor's strides are contiguous and its values are unitialized.
  1877. """
  1878. # TODO: review support arbitrary resizes
  1879. resize = _make_prim(
  1880. schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
  1881. meta=_resize_meta,
  1882. impl_aten=_resize_aten,
  1883. return_type=RETURN_TYPE.INPLACE,
  1884. doc=_resize_doc,
  1885. )
  1886. def _reduction_meta(inp, dims, *, output_dtype=None):
  1887. """
  1888. Meta function for single output reduction operations
  1889. Stride logic is incorrect
  1890. """
  1891. assert isinstance(inp, TensorLike)
  1892. if output_dtype is None:
  1893. output_dtype = inp.dtype
  1894. output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
  1895. return TensorMeta(
  1896. shape=output_shape,
  1897. strides=utils.make_contiguous_strides_for(output_shape),
  1898. dtype=output_dtype,
  1899. device=inp.device,
  1900. )
  1901. def _var_reduction_meta(inp, dims, correction):
  1902. if utils.is_complex_dtype(inp.dtype):
  1903. output_dtype = utils.corresponding_real_dtype(inp.dtype)
  1904. else:
  1905. output_dtype = inp.dtype
  1906. return _reduction_meta(inp, dims, output_dtype=output_dtype)
  1907. _sum_doc = """
  1908. Computes the sum of elements in the input tensor over the list of dimensions
  1909. specified in the dim argument
  1910. """
  1911. _xor_sum_doc = """
  1912. Computes the xor sum of elements in the input tensor over the list of dimensions
  1913. specified in the dim argument
  1914. """
  1915. _prod_doc = """
  1916. Computes the product of elements in the input tensor over the list of dimensions
  1917. specified in the dim argument
  1918. """
  1919. _amax_doc = """
  1920. Computes the maximum value of elements in the input tensor over the list of dimensions
  1921. specified in the dim argument
  1922. """
  1923. _amin_doc = """
  1924. Computes the minimum value of elements in the input tensor over the list of dimensions
  1925. specified in the dim argument
  1926. """
  1927. _var_doc = """
  1928. Computes the biased variance of x over the list of dimensions specified in the dim argument
  1929. """
  1930. def _make_reduction_prim(name: str, impl_aten, doc):
  1931. """Creates a reduction prim."""
  1932. return _make_prim(
  1933. schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
  1934. meta=_reduction_meta,
  1935. impl_aten=impl_aten,
  1936. return_type=RETURN_TYPE.NEW,
  1937. doc=doc,
  1938. )
  1939. def _make_var_reduction_prim(name: str, impl_aten, doc):
  1940. """Creates a reduction prim."""
  1941. return _make_prim(
  1942. schema=f"{name}(Tensor inp, int[]? dims, float? correction=1, *, ScalarType? output_dtype=None) -> Tensor",
  1943. meta=_var_reduction_meta,
  1944. impl_aten=impl_aten,
  1945. return_type=RETURN_TYPE.NEW,
  1946. doc=doc,
  1947. )
  1948. sum = _make_reduction_prim(
  1949. name="sum",
  1950. impl_aten=torch.sum,
  1951. doc=_sum_doc,
  1952. )
  1953. def _xor_sum_aten(
  1954. inp: TensorLikeType,
  1955. dims: Optional[DimsSequenceType],
  1956. *,
  1957. dtype: Optional[torch.dtype] = None,
  1958. ) -> Tensor:
  1959. raise NotImplementedError("xor_sum only implemented with inductor")
  1960. xor_sum = _make_reduction_prim(
  1961. name="xor_sum",
  1962. impl_aten=_xor_sum_aten,
  1963. doc=_xor_sum_doc,
  1964. )
  1965. def _prod_aten(
  1966. inp: TensorLikeType,
  1967. dims: Optional[DimsSequenceType],
  1968. *,
  1969. dtype: Optional[torch.dtype] = None,
  1970. ) -> Tensor:
  1971. if dims is not None:
  1972. if len(dims) == 0:
  1973. return inp.clone()
  1974. for d in sorted(dims, reverse=True):
  1975. assert d >= 0
  1976. inp = torch.prod(inp, d, dtype=dtype)
  1977. return inp
  1978. else:
  1979. return torch.prod(inp, dims, dtype=dtype)
  1980. prod = _make_reduction_prim(
  1981. name="prod",
  1982. impl_aten=_prod_aten,
  1983. doc=_prod_doc,
  1984. )
  1985. # torch.var, but correction is not kwarg-only
  1986. def torch_var(input, dim=None, correction=1, **kwargs):
  1987. return torch.var(input, dim=dim, correction=correction, **kwargs)
  1988. var = _make_var_reduction_prim(
  1989. name="var",
  1990. impl_aten=torch_var,
  1991. doc=_var_doc,
  1992. )
  1993. amax = _make_reduction_prim(
  1994. name="amax",
  1995. impl_aten=torch.amax,
  1996. doc=_amax_doc,
  1997. )
  1998. amin = _make_reduction_prim(
  1999. name="amin",
  2000. impl_aten=torch.amin,
  2001. doc=_amin_doc,
  2002. )
  2003. _iota_doc = """
  2004. Constructs a 1-D tensor t where ``t[i] == start + i * step``.
  2005. """
  2006. # TODO: layout, pin_memory, memory_format
  2007. # TODO: model requires_grad on TensorMeta
  2008. def _iota_meta(
  2009. length: int,
  2010. *,
  2011. start: int,
  2012. step: int,
  2013. dtype: torch.dtype,
  2014. device: torch.device,
  2015. requires_grad: bool,
  2016. ) -> TensorLikeType:
  2017. torch._check(
  2018. utils.is_integer_dtype(dtype),
  2019. lambda: "prims.iota only supports integer dtypes",
  2020. )
  2021. torch._check(step != 0, lambda: "step must be nonzero")
  2022. return torch.empty(
  2023. length,
  2024. dtype=dtype,
  2025. device=device,
  2026. requires_grad=requires_grad,
  2027. )
  2028. def _iota_aten(
  2029. length: int,
  2030. *,
  2031. start: int,
  2032. step: int,
  2033. dtype: torch.dtype,
  2034. device: torch.device,
  2035. requires_grad: bool,
  2036. ) -> TensorLikeType:
  2037. end = start + length * step
  2038. return torch.arange(
  2039. start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
  2040. )
  2041. iota = _make_prim(
  2042. schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950
  2043. return_type=RETURN_TYPE.NEW,
  2044. meta=_iota_meta,
  2045. impl_aten=_iota_aten,
  2046. doc=_iota_doc,
  2047. )
  2048. # TODO: layout, pin_memory, memory_format
  2049. # TODO: model requires_grad on TensorMeta
  2050. def _empty_meta(
  2051. shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
  2052. ) -> TensorLikeType:
  2053. strides = utils.make_contiguous_strides_for(shape)
  2054. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2055. def _empty_aten(
  2056. shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
  2057. ) -> Tensor:
  2058. return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
  2059. _empty_doc = """
  2060. Creates a tensor with uninitialized values and the specified shape, dtype, and device.
  2061. """
  2062. empty = _make_prim(
  2063. schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
  2064. meta=_empty_meta,
  2065. impl_aten=_empty_aten,
  2066. return_type=RETURN_TYPE.NEW,
  2067. doc=_empty_doc,
  2068. )
  2069. def _empty_strided_meta(
  2070. shape: ShapeType,
  2071. strides: StrideType,
  2072. *,
  2073. dtype: torch.dtype,
  2074. device: torch.device,
  2075. requires_grad: bool,
  2076. ) -> TensorLikeType:
  2077. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2078. _empty_strided_doc = """
  2079. Creates a tensor with uninitialized values.
  2080. """
  2081. # TODO: add layout, pin_memory
  2082. empty_strided = _make_prim(
  2083. schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
  2084. return_type=RETURN_TYPE.NEW,
  2085. meta=_empty_strided_meta,
  2086. impl_aten=torch.empty_strided,
  2087. doc=_empty_strided_doc,
  2088. )
  2089. def _empty_permuted_meta(
  2090. shape: ShapeType,
  2091. physical_layout: DimsSequenceType,
  2092. *,
  2093. dtype: torch.dtype,
  2094. device: torch.device,
  2095. requires_grad: bool,
  2096. ) -> TensorLikeType:
  2097. p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
  2098. dim = len(shape)
  2099. torch._check(
  2100. len(physical_layout) == dim,
  2101. lambda: (
  2102. "Number of dimensions in the tensor input does not match the "
  2103. f"length of the physical layout; i.e. len(size) = {dim} "
  2104. f"is not equal to len(physical_layout) = {len(physical_layout)}"
  2105. ),
  2106. )
  2107. strides = [0] * len(shape)
  2108. seen_dims = set()
  2109. for p, l in enumerate(physical_layout):
  2110. torch._check(
  2111. 0 <= l < dim,
  2112. lambda: (
  2113. f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
  2114. f"{l} at index {p}). NB: negative dims "
  2115. "not currently supported; file an issue if you want it."
  2116. ),
  2117. )
  2118. torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
  2119. strides[l] = p_strides[p]
  2120. seen_dims.add(l)
  2121. return TensorMeta(
  2122. shape=shape,
  2123. strides=strides,
  2124. dtype=dtype,
  2125. device=device,
  2126. )
  2127. _empty_permuted_doc = """
  2128. Creates a tensor with uninitialized values according to some physical layout,
  2129. that is guaranteed to be non-overlapping and dense.
  2130. """
  2131. # TODO: add layout, pin_memory
  2132. empty_permuted = _make_prim(
  2133. schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950
  2134. return_type=RETURN_TYPE.NEW,
  2135. meta=_empty_permuted_meta,
  2136. impl_aten=torch.empty_permuted,
  2137. doc=_empty_permuted_doc,
  2138. )
  2139. def _full_meta(
  2140. shape: ShapeType,
  2141. fill_value: NumberType,
  2142. *,
  2143. dtype: torch.dtype,
  2144. device: torch.device,
  2145. requires_grad: bool,
  2146. ) -> TensorLikeType:
  2147. strides = utils.make_contiguous_strides_for(shape)
  2148. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2149. def _full_aten(
  2150. shape: ShapeType,
  2151. fill_value: NumberType,
  2152. *,
  2153. dtype: torch.dtype,
  2154. device: torch.device,
  2155. requires_grad: bool,
  2156. ) -> Tensor:
  2157. # Note that Mypy thinks torch.full can't accept a complex fill_value
  2158. return torch.full(
  2159. shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
  2160. )
  2161. _full_doc = """
  2162. Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
  2163. """
  2164. # TODO: add layout
  2165. full = _make_prim(
  2166. schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
  2167. meta=_full_meta,
  2168. impl_aten=_full_aten,
  2169. return_type=RETURN_TYPE.NEW,
  2170. doc=_full_doc,
  2171. )
  2172. def _full_like_meta(
  2173. a: TensorLikeType,
  2174. fill_value: NumberType,
  2175. *,
  2176. dtype: torch.dtype,
  2177. device: torch.device,
  2178. requires_grad: bool,
  2179. ) -> TensorLikeType:
  2180. strides = utils.compute_elementwise_output_strides(a)
  2181. if a.numel() == 0:
  2182. strides = a.stride()
  2183. return TensorMeta(a, strides=strides, dtype=dtype, device=device)
  2184. def _full_like_aten(
  2185. a: Tensor,
  2186. fill_value: NumberType,
  2187. *,
  2188. dtype: torch.dtype,
  2189. device: torch.device,
  2190. requires_grad: bool,
  2191. ) -> Tensor:
  2192. # Note that Mypy thinks torch.full can't accept a complex fill_value
  2193. return torch.full_like(
  2194. a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
  2195. )
  2196. _full_like_doc = """
  2197. Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
  2198. given tensor by default. The dtype and device settings can be overridden
  2199. by specifying them explicitly.
  2200. """
  2201. full_like = _make_prim(
  2202. schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
  2203. meta=_full_like_meta,
  2204. impl_aten=_full_like_aten,
  2205. return_type=RETURN_TYPE.NEW,
  2206. doc=_full_like_doc,
  2207. )
  2208. def _scalar_tensor_meta(
  2209. scalar: NumberType,
  2210. *,
  2211. dtype: torch.dtype,
  2212. device: torch.device,
  2213. ) -> TensorLikeType:
  2214. shape: ShapeType = []
  2215. strides = utils.make_contiguous_strides_for(shape)
  2216. return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)
  2217. def _scalar_tensor_aten(
  2218. scalar: NumberType,
  2219. *,
  2220. dtype: torch.dtype,
  2221. device: torch.device,
  2222. ) -> Tensor:
  2223. if isinstance(scalar, complex) and (
  2224. dtype is None or not utils.is_complex_dtype(dtype)
  2225. ):
  2226. raise TypeError("Complex scalar requires complex tensor dtype.")
  2227. # Note that Mypy thinks torch.scalar can't accept a complex scalar
  2228. return torch.scalar_tensor(scalar, dtype=dtype, device=device) # type: ignore[arg-type]
  2229. _scalar_tensor_doc = """
  2230. Wraps a Number into a Tensor with the specified dtype and device.
  2231. """
  2232. # TODO: add layout and pin_memory support
  2233. scalar_tensor = _make_prim(
  2234. schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
  2235. meta=_scalar_tensor_meta,
  2236. impl_aten=_scalar_tensor_aten,
  2237. return_type=RETURN_TYPE.NEW,
  2238. doc=_scalar_tensor_doc,
  2239. )
#
# Linear algebra (linalg) prims
#
def _svd_meta(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
    """Meta function for ``svd``: output metadata for ``(U, S, Vh)``.

    For ``A`` of shape ``(*batch, m, n)`` with ``k = min(m, n)``:
      * ``U`` has shape ``(*batch, m, m)`` if ``full_matrices`` else
        ``(*batch, m, k)``, with strides from
        ``make_contiguous_strides_for(..., row_major=False)``.
      * ``S`` has shape ``(*batch, k)``, contiguous strides, and the
        corresponding real dtype when ``A`` is complex (else ``A.dtype``).
      * ``Vh`` has shape ``(*batch, n, n)`` if ``full_matrices`` else
        ``(*batch, k, n)``; ``row_major`` is True only for CUDA devices.
    """
    utils.check_is_matrix(A, "linalg.svd")
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)

    A_shape = A.shape
    batch = A_shape[:-2]
    m, n = A_shape[-2:]
    k = min(m, n)

    shape_U = batch + (m, m if full_matrices else k)
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)

    shape_S = batch + (k,)
    strides_S = utils.make_contiguous_strides_for(shape_S)
    # Singular values are real: complex inputs get their corresponding real dtype.
    S = TensorMeta(
        shape=shape_S,
        strides=strides_S,
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
        device=A.device,
    )

    shape_Vh = batch + (n if full_matrices else k, n)
    # The CPU backend returns V, but the cuSolver backend returns V^H
    # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
    is_cuda = A.device.type == "cuda"
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
    # Also makes sure this is CUDA or HIP:
    # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
    # NOTE(review): this tests whether a CUDA/HIP runtime is available at all,
    # not whether ``A`` is on a CUDA device (unlike ``is_cuda`` above). A CPU
    # input on a CUDA-capable machine also takes this branch; confirm intent.
    if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available():
        Vh = Vh.conj()
    return U, S, Vh
  2274. def _svd_aten(
  2275. A: TensorLikeType, *, full_matrices: bool
  2276. ) -> Tuple[Tensor, Tensor, Tensor]:
  2277. return torch.linalg.svd(A, full_matrices=full_matrices)
  2278. _svd_doc = """
  2279. Returns the SVD of a matrix or batch of matrices.
  2280. The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
  2281. """
  2282. svd = _make_prim(
  2283. schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
  2284. meta=_svd_meta,
  2285. impl_aten=_svd_aten,
  2286. return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
  2287. doc=_svd_doc,
  2288. )
  2289. #
  2290. # Randomness Prims
  2291. #
  2292. def _normal_meta(
  2293. shape: ShapeType,
  2294. *,
  2295. mean: Union[float, complex],
  2296. std: float,
  2297. dtype: torch.dtype,
  2298. device: torch.device,
  2299. requires_grad: bool,
  2300. generator: Optional[torch.Generator] = None,
  2301. ) -> TensorLikeType:
  2302. torch._check(
  2303. std >= 0.0,
  2304. lambda: f"expected non-negative standard deviation, but got std={std}",
  2305. )
  2306. torch._check(
  2307. utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
  2308. lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
  2309. )
  2310. strides = utils.make_contiguous_strides_for(shape)
  2311. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2312. def _normal_aten(
  2313. shape: ShapeType,
  2314. *,
  2315. mean: Union[float, complex],
  2316. std: float,
  2317. dtype: torch.dtype,
  2318. device: torch.device,
  2319. requires_grad: bool,
  2320. generator: Optional[torch.Generator] = None,
  2321. ) -> Tensor:
  2322. a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
  2323. with torch.no_grad():
  2324. # NOTE: normal_ is incorrectly annotated to expect mean to be a float
  2325. a.normal_(mean, std, generator=generator) # type: ignore[arg-type]
  2326. return a
  2327. _normal_doc = """
  2328. Constructs a tensor filled with values drawn from a normal distribution with the specified mean
  2329. and standard deviation.
  2330. Only supports floating-point types.
  2331. """
  2332. normal = _make_prim(
  2333. schema=(
  2334. "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" # noqa: B950
  2335. ),
  2336. return_type=RETURN_TYPE.NEW,
  2337. meta=_normal_meta,
  2338. impl_aten=_normal_aten,
  2339. doc=_normal_doc,
  2340. )
  2341. def _uniform_meta(
  2342. shape: ShapeType,
  2343. *,
  2344. low: float,
  2345. high: float,
  2346. dtype: torch.dtype,
  2347. device: torch.device,
  2348. generator: Optional[torch.Generator] = None,
  2349. ) -> TensorLikeType:
  2350. strides = utils.make_contiguous_strides_for(shape)
  2351. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2352. def _uniform_aten(
  2353. shape: ShapeType,
  2354. *,
  2355. low: float,
  2356. high: float,
  2357. dtype: torch.dtype,
  2358. device: torch.device,
  2359. generator: Optional[torch.Generator] = None,
  2360. ) -> Tensor:
  2361. a = torch.empty(shape, dtype=dtype, device=device)
  2362. a.uniform_(low, high, generator=generator)
  2363. return a
  2364. _uniform_doc = """
  2365. Constructs a tensor filled with values drawn uniformly from low to high.
  2366. """
  2367. # TODO: we should more seriously review randomness modeling and prims
  2368. _uniform_helper = _make_prim(
  2369. schema=(
  2370. "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
  2371. ),
  2372. return_type=RETURN_TYPE.NEW,
  2373. meta=_uniform_meta,
  2374. impl_aten=_uniform_aten,
  2375. doc=_uniform_doc,
  2376. )
  2377. #
  2378. # FFT prims
  2379. #
  2380. def _fft_r2c_meta(
  2381. input: TensorLike,
  2382. *,
  2383. dim: DimsSequenceType,
  2384. onesided: bool,
  2385. ) -> TensorLikeType:
  2386. dim = utils.canonicalize_dims(input.ndim, dim)
  2387. utils.validate_no_repeating_dims(dim)
  2388. shape = list(input.shape)
  2389. if onesided:
  2390. last_dim = dim[-1]
  2391. shape[last_dim] = shape[last_dim] // 2 + 1
  2392. dtype = utils.corresponding_complex_dtype(input.dtype)
  2393. strides = utils.make_contiguous_strides_for(shape)
  2394. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
  2395. def _fft_r2c_aten(
  2396. input: TensorLike,
  2397. *,
  2398. dim: DimsSequenceType,
  2399. onesided: bool,
  2400. ) -> TensorLikeType:
  2401. normalization = 0 # No normalization
  2402. return torch._fft_r2c(input, dim, normalization, onesided)
  2403. _fft_r2c_doc = """
  2404. Performs a real to complex Fast Fourier Transform
  2405. """
  2406. fft_r2c = _make_prim(
  2407. schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
  2408. meta=_fft_r2c_meta,
  2409. impl_aten=_fft_r2c_aten,
  2410. return_type=RETURN_TYPE.NEW,
  2411. doc=_fft_r2c_doc,
  2412. )
  2413. def _fft_c2c_meta(
  2414. input: TensorLike,
  2415. *,
  2416. dim: DimsSequenceType,
  2417. forward: bool,
  2418. ) -> TensorLikeType:
  2419. dim = utils.canonicalize_dims(input.ndim, dim)
  2420. utils.validate_no_repeating_dims(dim)
  2421. shape = input.shape
  2422. strides = utils.make_contiguous_strides_for(shape)
  2423. return TensorMeta(
  2424. shape=shape, strides=strides, dtype=input.dtype, device=input.device
  2425. )
  2426. def _fft_c2c_aten(
  2427. input: TensorLike,
  2428. *,
  2429. dim: DimsSequenceType,
  2430. forward: bool,
  2431. ) -> TensorLikeType:
  2432. normalization = 0 # No normalization
  2433. return torch._fft_c2c(input, dim, normalization, forward)
  2434. _fft_c2c_doc = """
  2435. Performs either a Fast Fourier Transform, or its inverse
  2436. """
  2437. fft_c2c = _make_prim(
  2438. schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
  2439. meta=_fft_c2c_meta,
  2440. impl_aten=_fft_c2c_aten,
  2441. return_type=RETURN_TYPE.NEW,
  2442. doc=_fft_c2c_doc,
  2443. )
  2444. def _fft_c2r_meta(
  2445. input: TensorLike,
  2446. *,
  2447. dim: DimsSequenceType,
  2448. last_dim_size: int,
  2449. ) -> TensorLikeType:
  2450. dim = utils.canonicalize_dims(input.ndim, dim)
  2451. utils.validate_no_repeating_dims(dim)
  2452. shape = list(input.shape)
  2453. shape[dim[-1]] = last_dim_size
  2454. dtype = utils.corresponding_real_dtype(input.dtype)
  2455. strides = utils.make_contiguous_strides_for(shape)
  2456. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
  2457. def _fft_c2r_aten(
  2458. input: TensorLike,
  2459. *,
  2460. dim: DimsSequenceType,
  2461. last_dim_size: int,
  2462. ) -> TensorLikeType:
  2463. normalization = 0 # No normalization
  2464. return torch._fft_c2r(input, dim, normalization, last_dim_size)
  2465. _fft_c2r_doc = """
  2466. Performs a complex to real Inverse Fast Fourier Transform
  2467. """
  2468. fft_c2r = _make_prim(
  2469. schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
  2470. meta=_fft_c2r_meta,
  2471. impl_aten=_fft_c2r_aten,
  2472. return_type=RETURN_TYPE.NEW,
  2473. doc=_fft_c2r_doc,
  2474. )
  2475. def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
  2476. torch._check(
  2477. self.dtype.is_floating_point,
  2478. lambda: "torch.frexp() only supports floating-point dtypes",
  2479. )
  2480. return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)
  2481. frexp = _make_prim(
  2482. schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
  2483. meta=_frexp_meta,
  2484. return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
  2485. impl_aten=torch.frexp,
  2486. doc="",
  2487. )
  2488. def _make_token_aten() -> TensorLikeType:
  2489. return torch.empty(0)
  2490. _make_token = _make_prim(
  2491. schema="_make_token() -> Tensor",
  2492. meta=_make_token_aten,
  2493. return_type=RETURN_TYPE.NEW,
  2494. impl_aten=_make_token_aten,
  2495. doc="Creates a token used for keeping track of side effects.",
  2496. )
  2497. def _sink_tokens_aten(tokens) -> None:
  2498. pass
  2499. _sink_tokens = _make_prim(
  2500. schema="_sink_tokens(Tensor[] tokens) -> ()",
  2501. meta=_sink_tokens_aten,
  2502. return_type=RETURN_TYPE.NONE,
  2503. impl_aten=_sink_tokens_aten,
  2504. doc="Sink all of the tokens which were previously used for keeping track of side effects.",
  2505. )
# Module-level registration calls run at import time, in this order.
# NOTE(review): both helpers are defined/imported outside this section; their
# names suggest they register the RNG and debug prim families respectively.
# Confirm at their definitions.
register_rng_prims()
register_debug_prims()