decompositions.py

# mypy: allow-untyped-defs
import functools
import numbers
import operator
import sys
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch.nn.functional as F
from torch import sym_float, sym_int, Tensor
from torch._decomp import register_decomposition
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import (
    IntLike,
    NumberType,
    suggest_memory_format,
    TensorLike,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    out_wrapper,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch._ops.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(
    f: Callable,
    type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
    compute_dtype_only: bool = False,
):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [
            x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
        ]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        if compute_dtype_only:
            return r
        else:
            return tree_map(decrease_prec, r)

    return inner


compute_only_pw_cast_for_opmath = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    compute_dtype_only=True,
)
pw_cast_for_opmath = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)

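# The three wrappers above differ only in how dtypes are handled:
# pw_cast_for_opmath upcasts tensor args to the computation ("opmath") dtype
# (e.g. float32 for float16/bfloat16 inputs) and casts the result back,
# pw_cast_for_int_to_real additionally promotes integer inputs to floating
# point, and compute_only_pw_cast_for_opmath upcasts but leaves the result in
# the computation dtype.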
# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x

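# d/dx tanh(x) = 1 - tanh(x)^2, so the backward reuses the forward output y;
# conj_physical() keeps the formula correct for complex inputs.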
@register_decomposition(aten.tanh_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()


@register_decomposition(aten.sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()

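# softplus(x) = (1/beta) * log(1 + exp(beta * x)); its derivative is
# sigmoid(beta * x) = z / (z + 1) with z = exp(beta * x). Above the threshold
# the forward is treated as linear, so the gradient passes through unchanged.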
@register_decomposition(aten.softplus_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


@register_decomposition(aten.elu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )


@register_decomposition([aten.fill.Scalar])
def fill_scalar(self, value):
    return torch.full_like(self, value)


@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
    torch._check(
        value.dim() == 0,
        lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
    )
    return aten.copy(self, value)


@register_decomposition(aten.hardsigmoid)
@out_wrapper()
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        0.0,
    )


@register_decomposition(aten.hardtanh_backward)
@out_wrapper("grad_input")
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)


@register_decomposition(aten.hardswish)
@out_wrapper()
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@out_wrapper()
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        0.0,
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
@out_wrapper("grad_input")
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, 0, grad_output)


@register_decomposition(aten.leaky_relu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)

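# GELU(x) = x * Phi(x), where Phi is the standard normal CDF, so
# d/dx GELU(x) = Phi(x) + x * phi(x) (the "cdf + self * pdf" branch below);
# the "tanh" branch differentiates the tanh approximation of GELU by hand.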
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)
        left = 0.5 * self
        right = 1 + tanh_inner
        left_derivative = 0.5 * right
        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative
        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@out_wrapper()
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)

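# d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is what the
# expression below computes with an explicit sigmoid.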
@register_decomposition(aten.silu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten._prelu_kernel)
def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
    return torch.where(self > 0, self, weight * self)


@register_decomposition(aten._prelu_kernel_backward)
def _prelu_kernel_backward(
    grad_output: Tensor,
    self: Tensor,
    weight: Tensor,
) -> Tuple[Tensor, Tensor]:
    input_grad = torch.where(self > 0, grad_output, weight * grad_output)
    weight_grad = torch.where(self > 0, 0.0, self * grad_output)
    return (input_grad, weight_grad)

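# In training mode RReLU samples a per-element negative slope r ~ U(lower, upper)
# and records it in `noise` (1 for positive inputs); in eval mode it behaves like
# leaky_relu with the mean slope (lower + upper) / 2.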
@register_decomposition(aten.rrelu_with_noise)
@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    assert generator is None
    if training:
        not_positive = self <= 0
        r = aten.uniform(self, lower, upper)
        output = torch.where(not_positive, self * r, self)
        noise.copy_(torch.where(not_positive, r, 1))
        return output
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu(self, negative_slope)


@register_decomposition(aten.rrelu_with_noise_)
@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA)
@pw_cast_for_opmath
def rrelu_with_noise_(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))


@register_decomposition(aten.rrelu_with_noise_backward)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(
            grad_output, self, negative_slope, self_is_result
        )


@register_decomposition(aten.log_sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64


# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction
@register_decomposition(aten.mse_loss)
@out_wrapper()
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.mse_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output


@register_decomposition(aten.smooth_l1_loss)
@out_wrapper()
@pw_cast_for_opmath
def smooth_l1_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    beta: float = 1.0,
):
    loss = (self - target).abs()
    loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.smooth_l1_loss_backward.default)
@pw_cast_for_opmath
def smooth_l1_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    abs_x = torch.abs(x)
    norm_grad = norm * grad_output
    return torch.where(
        abs_x < beta,
        norm_grad * x / beta,
        norm_grad * torch.sign(x),
    )


@register_decomposition(aten.smooth_l1_loss_backward.grad_input)
@pw_cast_for_opmath
def smooth_l1_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    beta: float,
    grad_input: Tensor,
):
    result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


@register_decomposition(aten.huber_loss_backward.default)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )


# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
@register_decomposition(aten.huber_loss_backward.out)
@pw_cast_for_opmath
def huber_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    delta: float,
    grad_input: Tensor,
):
    result = huber_loss_backward(grad_output, self, target, reduction, delta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)

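# Shared helper for the nll_loss backward variants: scatter -1 into a zeroed
# gradient at each (non-ignored) target index along the class dimension, scale
# grad_output by the per-class weights (and divide by total_weight under mean
# reduction), and zero out positions whose target equals ignore_index.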
def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight
    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)
    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight
    grad_output = torch.where(target != ignore_index, grad_output, 0)
    return grad_input * grad_output

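# GLU splits the input into halves a, b along `dim` and computes a * sigmoid(b),
# so the gradients are dL/da = grad * sigmoid(b) and
# dL/db = grad * a * sigmoid(b) * (1 - sigmoid(b)), concatenated back along `dim`.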
@register_decomposition(aten.glu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    wrap_dim = utils.canonicalize_dim(self.dim(), dim)
    nIn = self.size(wrap_dim)
    assert (
        nIn % 2 == 0
    ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
    inputSize = nIn // 2
    firstHalf = self.narrow(wrap_dim, 0, inputSize)
    secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
    gradInputFirstHalf = torch.sigmoid(secondHalf)
    gradInputSecondHalf = (
        (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
    )
    gradInputFirstHalf = gradInputFirstHalf * grad_output
    return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)


@register_decomposition(aten.nll_loss_backward)
@out_wrapper("grad_input")
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"
    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )
    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"
    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.nll_loss2d_backward)
@out_wrapper("grad_input")
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
    assert (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )
    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )

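# binary_cross_entropy(x, t) = -(t * log(x) + (1 - t) * log(1 - x)); the log
# terms are clamped at -100 so that x == 0 or x == 1 still produces a finite loss.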
@register_decomposition(aten.binary_cross_entropy)
@out_wrapper()
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log1p(-self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    loss = torch.log1p(torch.exp(-input * target))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.soft_margin_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def soft_margin_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
    if reduction == Reduction.MEAN.value:
        grad_input = grad_input / self.numel()
    return grad_input


@register_decomposition(aten.dist)
@out_wrapper()
def dist(input: Tensor, other: Tensor, p: float = 2):
    return aten.norm(input - other, p=p)

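# Pairwise distances via the expansion
# ||x1 - x2||^2 = ||x1||^2 - 2 * x1 . x2 + ||x2||^2, computed with a single
# matmul on augmented matrices; clamp_min(0) guards against small negative
# values from floating-point error before the sqrt.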
@register_decomposition(aten._euclidean_dist)
@out_wrapper()
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()


@register_decomposition(aten.slice_backward)
@out_wrapper()
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.slice.Tensor)
def slice_forward(
    # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
    self: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = utils.canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())
    if step <= 0:
        raise RuntimeError("slice step must be positive")
    start_val = start if start is not None else 0
    end_val = end if end is not None else sys.maxsize  # 2^63 - 1
    if start_val < 0:
        start_val += sizes[dim]
    if end_val < 0:
        end_val += sizes[dim]
    if start_val < 0:
        start_val = 0
    elif start_val > sizes[dim]:
        start_val = sizes[dim]
    if end_val < start_val:
        end_val = start_val
    elif end_val > sizes[dim]:
        end_val = sizes[dim]
    storage_offset = self.storage_offset() + start_val * strides[dim]
    len = end_val - start_val
    sizes[dim] = (len + step - 1) // step
    strides[dim] *= step
    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors is not implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)


@register_decomposition(aten.select_backward)
@out_wrapper()
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
@out_wrapper()
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)


def _cast_grad_to_input_dtype(
    grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
):
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input


@register_decomposition(aten._softmax_backward_data)
@out_wrapper("grad_input")
@compute_only_pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    new_grad_output = grad_output * output
    grad_input = new_grad_output - output * torch.sum(
        new_grad_output, dim=dim, keepdim=True
    )
    # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()
    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()


@register_decomposition(aten._log_softmax_backward_data)
@out_wrapper()
@compute_only_pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)

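# The helper below returns a (kernel_d, num_blocks_d) index matrix whose entry
# [i, j] = j * stride_d + i * dilation_d, i.e. the padded-input position read by
# kernel tap i within sliding block j along dimension d.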
def _im2col_col2im_indices_along_dim(
    input_d, kernel_d, dilation_d, padding_d, stride_d, device
):
    """Utility function to implement im2col and col2im"""
    blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
    arange_kw = partial(torch.arange, dtype=torch.int64, device=device)
    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)
    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)
    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    return blocks_d_indices + kernel_grid


@register_decomposition(aten.im2col)
@out_wrapper()
def im2col(
    input: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")
    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
        lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    output_size = tuple(
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            shape[-2:], padding, dilation, kernel_size, stride
        )
    )
    torch._check(
        all(c > 0 for c in output_size),
        lambda: f"Given an input with spatial size {tuple(shape[-2:])}, "
        f"kernel_size={kernel_size}, dilation={dilation}, "
        f"padding={padding}, stride={stride}, "
        "the calculated shape of the array of sliding blocks "
        f"is {output_size}, but its components must be at least one.",
    )
    batched_input = ndim == 4
    if not batched_input:
        input = input.unsqueeze(0)
    batch_dim, channel_dim, input_h, input_w = input.shape
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size
    blocks_row_indices = _im2col_col2im_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    blocks_col_indices = _im2col_col2im_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )
    # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
    # ugh
    padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
    blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
    output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
    output = output.permute(0, 1, 2, 4, 3, 5)
    num_blocks_row = blocks_row_indices.size(1)
    num_blocks_col = blocks_col_indices.size(1)
    output = output.reshape(
        batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
    )
    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.col2im)
@out_wrapper()
@pw_cast_for_opmath
def col2im(
    input: Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
    torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")
    check_positive(output_size, "output_size")
    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
        lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    prod_kernel_size = kernel_size[0] * kernel_size[1]
    torch._check(
        shape[-2] % prod_kernel_size == 0,
        lambda: "Expected size of input's first non-batch dimension to be divisible by the "
        f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
        f"kernel_size={kernel_size}",
    )
    col = [
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            output_size, padding, dilation, kernel_size, stride
        )
    ]
    L = col[0] * col[1]
    torch._check(
        shape[-1] == L,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    torch._check(
        L > 0,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    batched_input = ndim == 3
    if not batched_input:
        input = input.unsqueeze(0)
    shape = input.shape
    out_h, out_w = output_size
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size
    # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
    input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
    input = input.permute(0, 1, 2, 4, 3, 5)
    indices_row = _im2col_col2im_indices_along_dim(
        out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    indices_row = _unsqueeze_to_dim(indices_row, 4)
    indices_col = _im2col_col2im_indices_along_dim(
        out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )
    output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
    output = input.new_zeros(
        [shape[0], shape[1] // prod_kernel_size] + output_padded_size
    )
    idx = (None, None, indices_row, indices_col)
    output = aten._unsafe_index_put(output, idx, input, accumulate=True)
    output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))
    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.native_dropout_backward)
@out_wrapper()
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    # According to the CUDA kernel implementation we should have this test;
    # but it seems to fail tests!
    # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")

    # Mimicking CUDA kernel's behavior for output stride: output follows input's memory format
    # This is different from TensorIterator's behavior
    r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
        memory_format=utils.suggest_memory_format(grad_output)
    )
    return r

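# unfold_backward scatters each window of `grad` back to the positions it was
# read from: `idx` unfolds an arange over the original dimension into the same
# windows, and overlapping windows (step < size) are summed via an accumulating
# index_put.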
@register_decomposition(aten.unfold_backward)
@out_wrapper()
def unfold_backward(
    grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
) -> Tensor:
    if len(input_size) == 0:
        return torch.squeeze_copy(grad, 0)
    dim = utils.canonicalize_dim(len(input_size), dimension)
    idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
    idx = idx.unfold(0, size, step).flatten()
    grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
    # nb. At the moment this generates two kernels in triton
    # It could potentially be fused into one call to scatter_reduce,
    # in the case step <= size provided scatter_reduce generates 1 kernel
    grad_input = grad.new_zeros(input_size)
    index = (None,) * dim + (idx,)
    return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous()

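# logit(x) = log(x / (1 - x)), so d/dx logit(x) = 1 / (x * (1 - x)). With an eps
# clamp, inputs outside [eps, 1 - eps] get a zero gradient; without it, inputs
# outside [0, 1] produce NaN.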
  889. @register_decomposition(aten.logit_backward.default)
  890. @pw_cast_for_opmath
  891. def logit_backward(
  892. grad_output: Tensor, self: Tensor, eps: Optional[float] = None
  893. ) -> Tensor:
  894. if eps is not None:
  895. lo = eps
  896. hi = 1.0 - lo
  897. return torch.where(
  898. torch.logical_and(self >= lo, self <= hi),
  899. grad_output / (self * (1.0 - self)),
  900. 0.0,
  901. )
  902. else:
  903. return torch.where(
  904. torch.logical_and(self >= 0.0, self <= 1.0),
  905. grad_output / (self * (1.0 - self)),
  906. self.new_full((), float("nan")),
  907. )
  908. @register_decomposition(aten.dropout)
  909. @aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  910. @aten.dropout.default.py_impl(DispatchKey.Autograd)
  911. def dropout(input: Tensor, p: float, train: Optional[bool]):
  912. if train and p != 0:
  913. return aten.native_dropout(input, p, train)[0]
  914. else:
  915. return input.clone()
  916. @register_decomposition(aten.native_dropout)
  917. @out_wrapper("out0", "out1")
  918. def native_dropout(input: Tensor, p: float, train: Optional[bool]):
  919. if train and p != 0:
  920. if p == 1:
  921. return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
  922. if not input.dtype.is_floating_point:
  923. raise RuntimeError(
  924. "result type Float can't be cast to the desired output type Long"
  925. )
  926. bool_mask = torch.rand_like(input) > p
  927. res = bool_mask * input * float(1.0 / (1.0 - p))
  928. return (res, bool_mask)
  929. else:
  930. return (input, torch.ones_like(input, dtype=torch.bool))
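# Rough sanity check of the inverted-dropout scaling (illustrative only): each element is
# kept with probability 1 - p and multiplied by 1 / (1 - p), so the expectation is preserved:
#   >>> res, mask = torch.native_dropout(torch.ones(10_000), p=0.5, train=True)
#   >>> res.mean()  # ~1.0, with roughly half the entries zeroed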
  931. @register_decomposition(aten._softmax)
  932. @out_wrapper()
  933. def _softmax(x: Tensor, dim: int, half_to_float: bool):
  934. # eager softmax returns a contiguous tensor. Ensure that decomp also returns
  935. # a contiguous tensor.
  936. x = x.contiguous()
  937. if half_to_float:
  938. assert x.dtype == torch.half
  939. computation_dtype, result_dtype = utils.elementwise_dtypes(
  940. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  941. )
  942. x = x.to(computation_dtype)
  943. if x.numel() == 0:
  944. unnormalized = torch.exp(x)
  945. else:
  946. x_max = torch.amax(x, dim, keepdim=True)
  947. unnormalized = torch.exp(x - x_max)
  948. result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
  949. if not half_to_float:
  950. result = result.to(result_dtype)
  951. return result
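# NB: subtracting the per-dim max does not change the result, since
#   exp(x_i - m) / sum_j exp(x_j - m) == exp(x_i) / sum_j exp(x_j);
# it only keeps exp() from overflowing for large inputs.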
  952. @register_decomposition(aten._log_softmax)
  953. @out_wrapper()
  954. def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
  955. # eager log_softmax returns a contiguous tensor. Ensure that decomp also
  956. # returns a contiguous tensor.
  957. x = x.contiguous()
  958. if half_to_float:
  959. assert x.dtype == torch.half
  960. computation_dtype, result_dtype = utils.elementwise_dtypes(
  961. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  962. )
  963. x = x.to(computation_dtype)
  964. if x.numel() == 0:
  965. shifted = x
  966. else:
  967. x_max = torch.amax(x, dim, keepdim=True)
  968. shifted = x - x_max
  969. shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
  970. result = shifted - shifted_logsumexp
  971. if not half_to_float:
  972. result = result.to(result_dtype)
  973. return result
  974. @register_decomposition(aten.embedding)
  975. @out_wrapper()
  976. def embedding(
  977. weight: Tensor,
  978. indices: Tensor,
  979. padding_idx: int = -1,
  980. scale_grad_by_freq: bool = False,
  981. sparse: bool = False,
  982. ) -> Tensor:
  983. assert weight.dim() == 2, "'weight' must be 2-D"
  984. # Nb. scale_grad_by_freq is not used in the forward
  985. if indices.ndim <= 1:
  986. # We need this one as weight[indices] calls item() in these cases
  987. out = weight.index_select(0, indices)
  988. if indices.ndim == 0:
  989. out = out.squeeze(0)
  990. return out
  991. else:
  992. return weight[indices]
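# For illustration: with weight of shape (num_embeddings, embedding_dim) and indices of
# shape (B, T), `weight[indices]` gathers rows and returns shape (B, T, embedding_dim);
# the index_select branch above covers 0-D/1-D indices without triggering an item() call.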
  993. @register_decomposition(aten.embedding_dense_backward)
  994. @out_wrapper()
  995. def embedding_dense_backward(
  996. grad_output: Tensor,
  997. indices: Tensor,
  998. num_weights: int,
  999. padding_idx: int,
  1000. scale_grad_by_freq: bool,
  1001. ):
  1002. computation_dtype, result_dtype = utils.elementwise_dtypes(
  1003. grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  1004. )
  1005. grad_output = grad_output.to(computation_dtype)
  1006. indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment]
  1007. if scale_grad_by_freq:
  1008. counts = indices.new_zeros((num_weights,))
  1009. ones = torch.ones_like(indices)
  1010. counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True)
  1011. grad_weights_scale = counts[indices]
  1012. grad_output = grad_output / grad_weights_scale.unsqueeze(-1)
  1013. mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
  1014. grad = grad_output.masked_fill(mask, 0)
  1015. grad_weight = grad_output.new_zeros(
  1016. (num_weights,) + grad_output.shape[indices.ndim :]
  1017. )
  1018. return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to(
  1019. result_dtype
  1020. )
  1021. def prod(x: List[int]):
  1022. r = 1
  1023. for i in x:
  1024. r *= i
  1025. return r
  1026. def _pad_chunk(
  1027. tensors: List[Tensor],
  1028. dim: int,
  1029. num_chunks: int,
  1030. ) -> List[Tensor]:
  1031. padded_tensors = []
  1032. for tensor in tensors:
  1033. tensor_size = tensor.size()
  1034. pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
  1035. if pad_along_dim != tensor_size[dim]:
  1036. # Use aten.constant_pad_nd instead of copy_ for functionalization
  1037. pad = [0] * 2 * (tensor.ndim - dim - 1) + [
  1038. 0,
  1039. pad_along_dim - tensor_size[dim],
  1040. ]
  1041. tensor = aten.constant_pad_nd(tensor, pad, 0)
  1042. view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
  1043. padded_tensors.append(tensor.view(view_size))
  1044. return padded_tensors
  1045. def have_same_ndims(tensors: List[Tensor]):
  1046. ndim = tensors[0].ndim
  1047. for tensor in tensors:
  1048. if tensor.ndim != ndim:
  1049. return False
  1050. return True
  1051. def leading_dimension_matches(tensors: List[Tensor], dim: int):
  1052. leading_dim_sizes = tensors[0].size()[:dim]
  1053. for tensor in tensors:
  1054. torch._check(
  1055. tensor.size()[:dim] == leading_dim_sizes,
  1056. lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
  1057. )
  1058. def _preprocess_chunk_cat_inputs(
  1059. tensors: List[Tensor],
  1060. dim: int,
  1061. num_chunks: int,
  1062. ):
  1063. torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
  1064. torch._check(
  1065. len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
  1066. )
  1067. expected_dtype = tensors[0].dtype
  1068. expected_device = tensors[0].device
  1069. for tensor in tensors:
  1070. torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
  1071. torch._check(
  1072. tensor.dtype == expected_dtype,
  1073. lambda: "_chunk_cat expects all input tensors with the same dtype",
  1074. )
  1075. torch._check(
  1076. tensor.device == expected_device,
  1077. lambda: "_chunk_cat expects all inputs tensors on the same device",
  1078. )
  1079. if have_same_ndims(tensors):
  1080. dim = utils.canonicalize_dim(tensors[0].dim(), dim)
  1081. else:
  1082. torch._check(
  1083. dim >= 0,
  1084. lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
  1085. )
  1086. for tensor in tensors:
  1087. torch._check(
  1088. dim < tensor.ndim,
  1089. lambda: "_chunk_cat expects dim < ndim for all input tensors",
  1090. )
  1091. leading_dimension_matches(tensors, dim)
  1092. return dim
  1093. @register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
  1094. def _chunk_cat(
  1095. tensors: List[Tensor],
  1096. dim: int,
  1097. num_chunks: int,
  1098. out: Optional[Tensor] = None,
  1099. ) -> Tensor:
  1100. dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
  1101. padded_tensors = _pad_chunk(tensors, dim, num_chunks)
  1102. if out is None:
  1103. return torch.cat(padded_tensors, dim + 1)
  1104. else:
  1105. torch.cat(padded_tensors, dim + 1, out=out)
  1106. return out
  1107. @register_decomposition(aten.split_with_sizes)
  1108. def split_with_sizes(
  1109. self: Tensor, split_sizes: List[int], dim: int = 0
  1110. ) -> List[Tensor]:
  1111. # NB: Perform the check_is_size tests first so that the
  1112. # sum test does not try to do a replacement
  1113. for i in range(len(split_sizes)):
  1114. torch._check_is_size(
  1115. split_sizes[i],
  1116. lambda: "split_with_sizes expects split_sizes have only non-negative entries",
  1117. )
  1118. torch._check_with(
  1119. ValueError,
  1120. sum(split_sizes) == self.shape[dim],
  1121. lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
  1122. )
  1123. num_splits = len(split_sizes)
  1124. splits = []
  1125. start_idx = 0
  1126. # Avoid importing sympy at a module level
  1127. from torch.fx.experimental.symbolic_shapes import expect_true
  1128. for i in range(num_splits):
  1129. length = split_sizes[i]
  1130. # We know this is true thanks to the sum, but this assertion helps
  1131. # out our internal reasoning
  1132. expect_true(start_idx + length <= self.shape[dim])
  1133. splits.append(self.narrow(dim, start_idx, length))
  1134. start_idx += length
  1135. return splits
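# For illustration: with self.shape[dim] == 10 and split_sizes == [3, 3, 4] this returns
# narrow(dim, 0, 3), narrow(dim, 3, 3) and narrow(dim, 6, 4); the narrowed results are
# views into `self`.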
  1136. # out_wrapper currently does not allow optional outputs
  1137. @register_decomposition(
  1138. [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out]
  1139. )
  1140. def split_with_sizes_copy(
  1141. self: Tensor,
  1142. split_sizes: List[int],
  1143. dim: int = 0,
  1144. out: Optional[List[Tensor]] = None,
  1145. ) -> Optional[List[Tensor]]:
  1146. splits = split_with_sizes(self, split_sizes, dim=dim)
  1147. if out is None:
  1148. return [s.clone(memory_format=torch.contiguous_format) for s in splits]
  1149. else:
  1150. for output, split in zip(out, splits):
  1151. _maybe_resize_out(output, split.shape)
  1152. _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True)
  1153. return None
  1154. @register_decomposition(aten.unsafe_split.Tensor)
  1155. def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
  1156. return aten.split.Tensor(input, split_size, dim)
  1157. @register_decomposition(aten.unsafe_split_with_sizes.default)
  1158. def unsafe_split_with_sizes(
  1159. input: Tensor, split_sizes: List[int], dim: int = 0
  1160. ) -> Tuple[Tensor, ...]:
  1161. return aten.split_with_sizes.default(input, split_sizes, dim)
  1162. @register_decomposition(aten.split.Tensor)
  1163. def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
  1164. input_sizes = self.shape
  1165. dim_size = input_sizes[dim]
  1166. if split_size == 0:
  1167. assert dim_size == 0
  1168. return (self,)
  1169. chunks = (dim_size + split_size - 1) // split_size
  1170. # Avoid importing sympy at a module level
  1171. from torch.fx.experimental.symbolic_shapes import guard_int
  1172. chunks = guard_int(chunks)
  1173. split_sizes = [split_size for i in range(chunks)]
  1174. split_sizes[-1] = split_size - (split_size * chunks - dim_size)
  1175. return torch.split(self, split_sizes, dim)
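# Quick arithmetic check of the last chunk: with dim_size == 10 and split_size == 4 we get
# chunks == 3, and the last entry becomes 4 - (4 * 3 - 10) == 2, i.e. sizes [4, 4, 2].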
  1176. @aten.tensor_split.tensor_indices_or_sections.py_impl(
  1177. DispatchKey.CompositeImplicitAutograd
  1178. )
  1179. def tensor_split_tensor_indices_or_sections_py_impl(
  1180. self: Tensor,
  1181. tensor_indices_or_sections: Tensor,
  1182. dim: int = 0,
  1183. ) -> Tuple[Tensor, ...]:
  1184. assert tensor_indices_or_sections.device.type == "cpu"
  1185. assert tensor_indices_or_sections.dtype == torch.int64
  1186. split_dim = tensor_indices_or_sections.dim()
  1187. torch._check(
  1188. split_dim == 1 or split_dim == 0,
  1189. lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional "
  1190. f"or one-dimensional tensor, but got a tensor with {split_dim} dims",
  1191. )
  1192. if split_dim == 0:
  1193. sections = tensor_indices_or_sections.item()
  1194. assert isinstance(sections, IntLike)
  1195. return self.tensor_split(sections, dim)
  1196. else:
  1197. indices = [i.item() for i in tensor_indices_or_sections]
  1198. # WARNING: Tempted to torch._check_is_size on the indices here? You
  1199. # can't: tensor_split works with negative values in indices:
  1200. #
  1201. # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
  1202. # (tensor([ 0.3540, 2.1074, -0.8507, 1.1639, 0.3055]), tensor([]),
  1203. # tensor([-0.4285, 1.0692, -0.1776, 0.9362, 1.6143]))
  1204. #
  1205. # Sorry, I don't make the rules. Explicitly do the item call in user
  1206. # code if you KNOW that they are non-negative.
  1207. return self.tensor_split(indices, dim)
  1208. # TODO: this doesn't appear to have enough precision in bfloat16
  1209. @register_decomposition(aten.addmm)
  1210. @out_wrapper()
  1211. @pw_cast_for_opmath
  1212. def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
  1213. if not self.is_floating_point() and not self.is_complex():
  1214. beta = int(beta)
  1215. alpha = int(alpha)
  1216. out = alpha * torch.mm(mat1, mat2)
  1217. if beta == 0:
  1218. return out
1219. # The output of aten.addmm is contiguous; we need to match this behavior in the decomposition.
1220. # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
1221. # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument of the addition.
1222. # This relies on TensorIterator's behavior of giving higher precedence to the strides of its first input.
1223. # Alternatively, we could write `(beta * self + out).contiguous()`, but that introduces another copy in some cases.
  1224. # This implementation is not ideal, and we should revisit this when we have a better solution.
  1225. return out + beta * self
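# A minimal sketch of the note above (shapes are illustrative only): with a transposed
# self = torch.randn(8, 6).t() and a contiguous out = torch.mm(mat1, mat2) of shape (6, 8),
# `out + beta * self` inherits out's contiguous layout, whereas `beta * self + out` would
# inherit the strided layout of `self`.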
  1226. @register_decomposition(aten._addmm_activation)
  1227. @out_wrapper()
  1228. @pw_cast_for_opmath
  1229. def _addmm_activation(
  1230. self: Tensor,
  1231. mat1: Tensor,
  1232. mat2: Tensor,
  1233. beta: int = 1,
  1234. alpha: int = 1,
  1235. use_gelu: bool = False,
  1236. ):
  1237. out = addmm(self, mat1, mat2, beta, alpha)
  1238. if use_gelu:
  1239. if self.is_cuda:
  1240. return aten.gelu(out, approximate="tanh")
  1241. else:
  1242. return aten.gelu(out)
  1243. return aten.relu(out)
  1244. @register_decomposition(aten.addmv)
  1245. @out_wrapper()
  1246. @pw_cast_for_opmath
  1247. def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
  1248. if not self.is_floating_point() and not self.is_complex():
  1249. beta = int(beta)
  1250. alpha = int(alpha)
  1251. out = alpha * torch.mv(mat1, vec)
  1252. if beta == 0:
  1253. return out
  1254. return out + beta * self
  1255. @register_decomposition(aten.native_group_norm_backward.default)
  1256. @pw_cast_for_opmath
  1257. def native_group_norm_backward(
  1258. grad_output: Tensor,
  1259. input: Tensor,
  1260. mean: Tensor,
  1261. rstd: Tensor,
  1262. gamma: Optional[Tensor],
  1263. N: int,
  1264. C: int,
  1265. HxW: int,
  1266. group: int,
  1267. output_mask: List[bool],
  1268. ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1269. utils.check_same_device(
  1270. grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
  1271. )
  1272. utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
  1273. utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
  1274. torch._check(
  1275. input.numel() == N * C * HxW,
  1276. lambda: f"Expect input to have { N * C * HxW} elements",
  1277. )
  1278. torch._check(
  1279. mean.shape == (N, group),
  1280. lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
  1281. )
  1282. torch._check(
  1283. gamma is None or gamma.numel() == C,
  1284. lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
  1285. )
  1286. cpg, _rem = divmod(C, group)
  1287. torch._check(
  1288. _rem == 0,
  1289. lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
  1290. )
  1291. # Compute Internal gradients
  1292. ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
  1293. db = grad_output.view(N, C, HxW).sum(dim=[2])
  1294. d_input: Optional[Tensor] = None
  1295. d_gamma: Optional[Tensor] = None
  1296. d_bias: Optional[Tensor] = None
  1297. if output_mask[0]:
  1298. s = 1.0 / (HxW * cpg)
  1299. if gamma is not None:
  1300. ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
  1301. db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
  1302. c1 = torch.mul(
  1303. rstd.unsqueeze(-1),
  1304. gamma.reshape(1, group, cpg),
  1305. )
  1306. else:
  1307. ds_val = ds.reshape(N, group, cpg).sum(2)
  1308. db_val = db.reshape(N, group, cpg).sum(2)
  1309. c1 = torch.mul(
  1310. rstd.unsqueeze(-1),
  1311. torch.ones((1, group, cpg), device=rstd.device),
  1312. )
  1313. c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
  1314. c3 = -c2 * mean - db_val * rstd * s
  1315. c1 = c1.unsqueeze(-1)
  1316. c2 = _unsqueeze_to_dim(c2, 4)
  1317. c3 = _unsqueeze_to_dim(c3, 4)
  1318. d_input = (
  1319. torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
  1320. + torch.mul(input.reshape(N, group, cpg, HxW), c2)
  1321. + c3
  1322. )
  1323. d_input = d_input.reshape(input.shape).to(input.dtype)
  1324. if output_mask[1]:
  1325. d_gamma = (
  1326. (
  1327. (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
  1328. * rstd.unsqueeze(-1)
  1329. )
  1330. .sum(dim=[0])
  1331. .reshape(C)
  1332. )
  1333. if output_mask[2]:
  1334. d_bias = db.sum(dim=[0])
  1335. return (d_input, d_gamma, d_bias)
  1336. # out_wrapper currently does not allow optional outputs
  1337. @register_decomposition(aten.native_group_norm_backward.out)
  1338. def native_group_norm_backward_out(
  1339. grad_output: Tensor,
  1340. input: Tensor,
  1341. mean: Tensor,
  1342. rstd: Tensor,
  1343. gamma: Optional[Tensor],
  1344. N: int,
  1345. C: int,
  1346. HxW: int,
  1347. group: int,
  1348. output_mask: List[bool],
  1349. *,
  1350. out0: torch.Tensor,
  1351. out1: torch.Tensor,
  1352. out2: torch.Tensor,
  1353. ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1354. result = native_group_norm_backward(
  1355. grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
  1356. )
  1357. grad_input = (out0, out1, out2)
  1358. for i, r in enumerate(result):
  1359. if r is not None:
  1360. _maybe_resize_out(grad_input[i], r.shape)
  1361. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  1362. return grad_input
  1363. def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
  1364. if x is not None:
  1365. return x.to(dtype)
  1366. return x
  1367. # TODO: Take a closer look at the type promotion semantics
  1368. @register_decomposition(aten.native_layer_norm_backward.default)
  1369. def native_layer_norm_backward(
  1370. grad_out: Tensor,
  1371. input: Tensor,
  1372. normalized_shape: List[int],
  1373. mean: Tensor,
  1374. rstd: Tensor,
  1375. weight: Optional[Tensor],
  1376. bias: Optional[Tensor],
  1377. output_mask: List[bool],
  1378. ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1379. input_shape = input.shape
  1380. input_ndim = input.dim()
  1381. computation_dtype = utils.get_computation_dtype(input.dtype)
  1382. grad_out_cast, input_cast, weight_cast, bias_cast = (
  1383. x.to(computation_dtype).contiguous() if x is not None else x
  1384. for x in (grad_out, input, weight, bias)
  1385. )
  1386. assert grad_out_cast is not None
  1387. axis = input_ndim - len(normalized_shape)
  1388. inner_dims = input_shape[axis:]
  1389. outer_dims = input_shape[:axis]
  1390. inner_dim_indices: List[int] = []
  1391. outer_dim_indices: List[int] = []
  1392. for i in range(input_ndim):
  1393. if i >= axis:
  1394. inner_dim_indices.append(i)
  1395. else:
  1396. outer_dim_indices.append(i)
  1397. N = prod(inner_dims) # type: ignore[arg-type]
  1398. M = prod(outer_dims) # type: ignore[arg-type]
  1399. if M <= 0 or N <= 0:
  1400. return (
  1401. input.new_zeros(input_shape) if output_mask[0] else None,
  1402. input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
  1403. input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
  1404. )
  1405. mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr]
  1406. rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
  1407. x_hat = (input_cast - mean) * rstd
  1408. if weight_cast is not None:
  1409. grad_x_hat = grad_out_cast * weight_cast
  1410. else:
  1411. grad_x_hat = grad_out_cast
  1412. a = grad_x_hat * N
  1413. b = torch.sum(grad_x_hat, inner_dim_indices, True)
  1414. c1 = torch.mul(grad_x_hat, x_hat)
  1415. c2 = torch.sum(c1, inner_dim_indices, True)
  1416. c3 = torch.mul(x_hat, c2)
  1417. inner = a - b - c3
  1418. d_input: Optional[Tensor] = None
  1419. d_weight: Optional[Tensor] = None
  1420. d_bias: Optional[Tensor] = None
  1421. if output_mask[0]:
  1422. d_input = (rstd / N) * inner
  1423. if output_mask[1] and weight_cast is not None:
  1424. if len(outer_dim_indices) > 0:
  1425. d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
  1426. else:
  1427. d_weight = grad_out_cast * x_hat
  1428. if output_mask[2] and bias_cast is not None:
  1429. if len(outer_dim_indices) > 0:
  1430. d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
  1431. else:
  1432. d_bias = grad_out_cast.clone()
  1433. return (
  1434. _maybe_cast(d_input, input.dtype),
  1435. _maybe_cast(d_weight, input.dtype),
  1436. _maybe_cast(d_bias, input.dtype),
  1437. )
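# NB: the d_input branch above is the usual layer norm backward,
#   d_input = (rstd / N) * (N * g - sum(g) - x_hat * sum(g * x_hat))
# with g = grad_out * weight (or just grad_out when weight is None) and the sums taken
# over the normalized (inner) dimensions.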
  1438. # out_wrapper currently does not allow optional outputs
  1439. @register_decomposition(aten.native_layer_norm_backward.out)
  1440. def native_layer_norm_backward_out(
  1441. grad_out: Tensor,
  1442. input: Tensor,
  1443. normalized_shape: List[int],
  1444. mean: Tensor,
  1445. rstd: Tensor,
  1446. weight: Optional[Tensor],
  1447. bias: Optional[Tensor],
  1448. output_mask: List[bool],
  1449. *,
  1450. out0: torch.Tensor,
  1451. out1: torch.Tensor,
  1452. out2: torch.Tensor,
  1453. ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1454. result = native_layer_norm_backward(
  1455. grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
  1456. )
  1457. grad_input = (out0, out1, out2)
  1458. for i, r in enumerate(result):
  1459. if r is not None:
  1460. _maybe_resize_out(grad_input[i], r.shape)
  1461. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  1462. return grad_input
  1463. def native_batch_norm_helper(
  1464. input: Tensor,
  1465. weight: Optional[Tensor],
  1466. bias: Optional[Tensor],
  1467. running_mean: Optional[Tensor],
  1468. running_var: Optional[Tensor],
  1469. training: bool,
  1470. momentum: float,
  1471. eps: float,
  1472. functional: bool,
  1473. ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
  1474. reduction_dims = [0] + list(range(2, input.dim()))
  1475. computation_dtype = utils.get_computation_dtype(input.dtype)
  1476. new_running_mean = running_mean
  1477. new_running_var = running_var
  1478. if training:
  1479. computation_dtype = utils.get_computation_dtype(input.dtype)
  1480. input_acc = input.to(dtype=computation_dtype)
  1481. biased_var, mean = torch.var_mean(
  1482. input_acc, dim=reduction_dims, correction=0, keepdim=True
  1483. )
  1484. rstd = torch.rsqrt(biased_var + eps)
  1485. output = (input - mean) * rstd
  1486. save_mean = torch.squeeze(mean, reduction_dims)
  1487. save_rstd = torch.squeeze(rstd, reduction_dims)
  1488. if running_mean is not None:
  1489. new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
  1490. if not functional:
  1491. running_mean.copy_(new_running_mean)
  1492. if running_var is not None:
  1493. n = input.numel() / input.shape[1]
1494. # This doesn't strictly match eager's numerics, which accumulates the var sum and then directly applies the correction.
1495. # But... that would require re-implementing var here, for a negligible numerics gain on a tensor whose
  1496. # numerics probably don't matter.
  1497. squeezed_var = torch.squeeze(biased_var, reduction_dims)
  1498. unbiased_var = squeezed_var * (n / (n - 1))
  1499. new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
  1500. if not functional:
  1501. running_var.copy_(new_running_var)
  1502. else:
  1503. assert running_mean is not None and running_var is not None
  1504. running_mean = running_mean.to(dtype=computation_dtype, copy=True)
  1505. new_running_mean = running_mean
  1506. running_var = running_var.to(dtype=computation_dtype, copy=True)
  1507. new_running_var = running_var
  1508. mean = running_mean
  1509. invstd = 1 / (torch.sqrt(running_var + eps))
  1510. # Very annoying inconsistency where CPU and CUDA give different shapes
  1511. if input.device.type != "cpu":
  1512. save_mean = running_mean
  1513. save_rstd = invstd
  1514. else:
  1515. save_mean = input.new_zeros((0,))
  1516. save_rstd = input.new_zeros((0,))
  1517. mean = _unsqueeze_to_dim(mean, input.dim() - 1)
  1518. invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
  1519. output = (input - mean) * invstd
  1520. if weight is not None:
  1521. weight = weight.flatten()
  1522. weight = _unsqueeze_to_dim(weight, input.dim() - 1)
  1523. output = output * weight
  1524. if bias is not None:
  1525. bias = bias.flatten()
  1526. bias = _unsqueeze_to_dim(bias, input.dim() - 1)
  1527. output = output + bias
  1528. if input.device.type == "cpu":
  1529. save_mean = save_mean.to(dtype=input.dtype)
  1530. save_rstd = save_rstd.to(dtype=input.dtype)
  1531. return (
  1532. output.to(dtype=input.dtype),
  1533. save_mean,
  1534. save_rstd,
  1535. new_running_mean,
  1536. new_running_var,
  1537. )
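# Quick check of the running-stat update above: with momentum == 0.1, running_mean == 0 and
# a batch mean of 1.0, the new running mean is 0.1 * 1.0 + 0.9 * 0 == 0.1; the batch variance
# is unbiased with the usual n / (n - 1) factor before being folded into running_var.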
  1538. @register_decomposition(aten.native_batch_norm)
  1539. @out_wrapper("out", "save_mean", "save_invstd")
  1540. def native_batch_norm(
  1541. input: Tensor,
  1542. weight: Optional[Tensor],
  1543. bias: Optional[Tensor],
  1544. running_mean: Optional[Tensor],
  1545. running_var: Optional[Tensor],
  1546. training: bool,
  1547. momentum: float,
  1548. eps: float,
  1549. ) -> Tuple[Tensor, Tensor, Tensor]:
  1550. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1551. input, weight, bias, running_mean, running_var, training, momentum, eps, False
  1552. )
  1553. return output, save_mean, save_rstd
  1554. # TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
  1555. # with our new correctly schema'd _native_batch_norm_legit and its variants, but
  1556. # we cannot do that immediately in the C++ because it would be forwards incompatible
  1557. # with some mobile use cases.
  1558. #
  1559. # Since this change is most impactful for aot autograd/functionalization, we simply
  1560. # register this decomposition on the Autograd key for the python dispatcher (which is
  1561. # currently only used by aot autograd/functionalization and no one else, really).
  1562. # In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
  1563. # to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
  1564. @aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
  1565. @aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  1566. def native_batch_norm_decomposition(
  1567. input: Tensor,
  1568. weight: Optional[Tensor],
  1569. bias: Optional[Tensor],
  1570. running_mean: Optional[Tensor],
  1571. running_var: Optional[Tensor],
  1572. training: bool,
  1573. momentum: float,
  1574. eps: float,
  1575. ) -> Tuple[Tensor, Tensor, Tensor]:
  1576. if running_mean is None and running_var is None:
  1577. return aten._native_batch_norm_legit(
  1578. input, weight, bias, training, momentum, eps
  1579. )
  1580. if running_mean is None:
  1581. raise RuntimeError(
  1582. "running_mean is None, but running_var is provided. "
  1583. "They should both be None or both be provided."
  1584. )
  1585. if running_var is None:
  1586. raise RuntimeError(
  1587. "running_var is None, but running_mean is provided. "
  1588. "They should both be None or both be provided."
  1589. )
  1590. if training:
  1591. # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg.
  1592. return aten._native_batch_norm_legit(
  1593. input, weight, bias, running_mean, running_var, training, momentum, eps
  1594. )
  1595. else:
  1596. return aten._native_batch_norm_legit_no_training(
  1597. input, weight, bias, running_mean, running_var, momentum, eps
  1598. )
  1599. @aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  1600. def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
  1601. dim_size = tensor.size(dim)
  1602. split_size = (dim_size + chunks - 1) // chunks
  1603. if split_size == 0 and dim_size == 0:
1604. split_sizes = [split_size for _ in range(chunks)]
  1605. split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
  1606. return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
  1607. return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
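# For illustration: chunks == 3 on a dimension of size 10 gives split_size == 4, and the
# final unsafe_split yields pieces of sizes [4, 4, 2]; the branch above only triggers for an
# empty dimension, in which case every chunk is empty.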
  1608. @register_decomposition(aten._native_batch_norm_legit_no_training.default)
  1609. def _native_batch_norm_legit_no_training(
  1610. input: Tensor,
  1611. weight: Optional[Tensor],
  1612. bias: Optional[Tensor],
  1613. running_mean: Tensor,
  1614. running_var: Tensor,
  1615. momentum: float,
  1616. eps: float,
  1617. ) -> Tuple[Tensor, Tensor, Tensor]:
  1618. return aten._native_batch_norm_legit.default(
  1619. input,
  1620. weight,
  1621. bias,
  1622. running_mean,
  1623. running_var,
  1624. False, # training
  1625. momentum,
  1626. eps,
  1627. )
  1628. @register_decomposition(aten._native_batch_norm_legit.default)
  1629. def _native_batch_norm_legit(
  1630. input: Tensor,
  1631. weight: Optional[Tensor],
  1632. bias: Optional[Tensor],
  1633. running_mean: Tensor,
  1634. running_var: Tensor,
  1635. training: bool,
  1636. momentum: float,
  1637. eps: float,
  1638. ) -> Tuple[Tensor, Tensor, Tensor]:
  1639. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1640. input, weight, bias, running_mean, running_var, training, momentum, eps, False
  1641. )
  1642. return output, save_mean, save_rstd
  1643. @register_decomposition(aten._native_batch_norm_legit.no_stats)
  1644. def _native_batch_norm_legit_no_stats(
  1645. input: Tensor,
  1646. weight: Optional[Tensor],
  1647. bias: Optional[Tensor],
  1648. training: bool,
  1649. momentum: float,
  1650. eps: float,
  1651. ) -> Tuple[Tensor, Tensor, Tensor]:
  1652. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1653. input, weight, bias, None, None, training, momentum, eps, False
  1654. )
  1655. return output, save_mean, save_rstd
  1656. @register_decomposition(aten._native_batch_norm_legit_functional.default)
  1657. def _native_batch_norm_legit_functional(
  1658. input: Tensor,
  1659. weight: Optional[Tensor],
  1660. bias: Optional[Tensor],
  1661. running_mean: Tensor,
  1662. running_var: Tensor,
  1663. training: bool,
  1664. momentum: float,
  1665. eps: float,
  1666. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  1667. (
  1668. output,
  1669. save_mean,
  1670. save_rstd,
  1671. new_running_mean,
  1672. new_running_var,
  1673. ) = native_batch_norm_helper(
  1674. input, weight, bias, running_mean, running_var, training, momentum, eps, True
  1675. )
  1676. assert new_running_mean is not None, "new_running_mean should not be None"
  1677. assert new_running_var is not None, "new_running_var should not be None"
  1678. return output, save_mean, save_rstd, new_running_mean, new_running_var
  1679. def _get_batch_norm_reserve_tensor(
  1680. input: Tensor,
  1681. weight: Optional[Tensor],
  1682. bias: Optional[Tensor],
  1683. running_mean: Tensor,
  1684. running_var: Tensor,
  1685. eps: float,
  1686. training: bool,
  1687. ) -> Tensor:
  1688. """
  1689. Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
  1690. backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
  1691. which support a variety of backends including cudnn. We create this tensor here to get
1692. the correct shape in the traced graph if we detect that we will call the cudnn kernel,
  1693. and rely on DCE to avoid materializing this tensor.
  1694. """
  1695. backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined]
  1696. input, weight, bias, running_mean, running_var, True, eps
  1697. )
  1698. reserve_size = 0
  1699. if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined]
  1700. reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined]
  1701. return torch.empty(
  1702. reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
  1703. )
  1704. @register_decomposition(aten._batch_norm_with_update.default)
  1705. def _batch_norm_with_update(
  1706. input: Tensor,
  1707. weight: Optional[Tensor],
  1708. bias: Optional[Tensor],
  1709. running_mean: Tensor,
  1710. running_var: Tensor,
  1711. momentum: float,
  1712. eps: float,
  1713. ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  1714. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1715. input,
  1716. weight,
  1717. bias,
  1718. running_mean,
  1719. running_var,
  1720. True, # training
  1721. momentum,
  1722. eps,
  1723. False, # functional
  1724. )
  1725. reserve = _get_batch_norm_reserve_tensor(
  1726. input, weight, bias, running_mean, running_var, eps, training=True
  1727. )
  1728. return output, save_mean, save_rstd, reserve
  1729. @register_decomposition(aten._batch_norm_with_update_functional.default)
  1730. def _batch_norm_with_update_functional(
  1731. input: Tensor,
  1732. weight: Optional[Tensor],
  1733. bias: Optional[Tensor],
  1734. running_mean: Tensor,
  1735. running_var: Tensor,
  1736. momentum: float,
  1737. eps: float,
  1738. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  1739. (
  1740. output,
  1741. save_mean,
  1742. save_rstd,
  1743. new_rm,
  1744. new_rv,
  1745. ) = native_batch_norm_helper(
  1746. input, weight, bias, running_mean, running_var, True, momentum, eps, True
  1747. )
  1748. reserve = _get_batch_norm_reserve_tensor(
  1749. input, weight, bias, running_mean, running_var, eps, training=True
  1750. )
  1751. assert new_rm is not None, "new_running_mean should not be None"
  1752. assert new_rv is not None, "new_running_var should not be None"
  1753. return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
  1754. @register_decomposition(aten._batch_norm_no_update.default)
  1755. def _batch_norm_no_update(
  1756. input: Tensor,
  1757. weight: Optional[Tensor],
  1758. bias: Optional[Tensor],
  1759. running_mean: Tensor,
  1760. running_var: Tensor,
  1761. momentum: float,
  1762. eps: float,
  1763. ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  1764. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1765. input,
  1766. weight,
  1767. bias,
  1768. running_mean,
  1769. running_var,
  1770. False, # training
  1771. momentum,
  1772. eps,
  1773. False, # functional
  1774. )
  1775. reserve = _get_batch_norm_reserve_tensor(
  1776. input, weight, bias, running_mean, running_var, eps, training=False
  1777. )
  1778. return output, save_mean, save_rstd, reserve
  1779. @register_decomposition(aten._fused_dropout)
  1780. @out_wrapper("out0", "out1")
  1781. @pw_cast_for_opmath
  1782. def _fused_dropout_decomposition(input, p, generator=None):
  1783. assert generator is None
  1784. mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
  1785. res = mask.type_as(input) * input * (1.0 / p)
  1786. return (res, mask)
  1787. @register_decomposition(aten._to_copy)
  1788. @out_wrapper()
  1789. def _to_copy(
  1790. x: Tensor,
  1791. *,
  1792. dtype: Optional[torch.dtype] = None,
  1793. layout=None,
  1794. device: Optional[torch.device] = None,
  1795. pin_memory: bool = False,
  1796. non_blocking: bool = False,
  1797. memory_format: Optional[torch.memory_format] = None,
  1798. ):
  1799. assert not layout or layout == torch.strided, "TODO"
  1800. assert not pin_memory, "TODO"
  1801. if device is None and dtype is None and memory_format is None:
  1802. return x.clone()
  1803. dtype_converted = False
  1804. if device is not None and device != x.device:
  1805. # avoid conversions on cpu
  1806. if dtype is not None and device.type == "cpu":
  1807. x = torch._prims.convert_element_type(x, dtype)
  1808. dtype_converted = True
  1809. x = torch._prims.device_put(x, device)
  1810. if dtype is not None and not dtype_converted:
  1811. x = torch._prims.convert_element_type(x, dtype)
  1812. dtype_converted = True
  1813. if memory_format is not None: # no ref/prim for memory format
  1814. return torch.clone(x, memory_format=memory_format)
  1815. return x
  1816. # Questionable decompositions
  1817. # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
  1818. # Note that this decomposition causes issues with in-place ops
  1819. @register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
  1820. @out_wrapper()
  1821. def nop_decomposition(x):
  1822. return aten.alias(x)
  1823. # Also register to the Autograd dispatch key, so this decomp can run above autograd.
  1824. # native_batch_norm needs to decompose into other ops before autograd.
  1825. @aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
  1826. @register_decomposition(aten.cudnn_batch_norm)
  1827. @out_wrapper("out0", "out1", "out2", "out3")
  1828. def cudnn_batch_norm(
  1829. input: Tensor,
  1830. weight: Tensor,
  1831. bias: Optional[Tensor],
  1832. running_mean: Optional[Tensor],
  1833. running_var: Optional[Tensor],
  1834. training: bool,
  1835. exponential_average_factor: float,
  1836. epsilon: float,
  1837. ):
  1838. a, b, c = aten.native_batch_norm(
  1839. input,
  1840. weight,
  1841. bias,
  1842. running_mean,
  1843. running_var,
  1844. training,
  1845. exponential_average_factor,
  1846. epsilon,
  1847. )
1848. # cuDNN returns running mean and variance when training is True
  1849. if training:
  1850. return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
  1851. return (
  1852. a,
  1853. weight.new_zeros((0,)),
  1854. weight.new_zeros((0,)),
  1855. input.new_zeros((0,), dtype=torch.uint8),
  1856. )
  1857. def _broadcast_batch_norm_backward(x, broadcast_mask):
  1858. for axis, mask in enumerate(broadcast_mask):
  1859. if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask):
  1860. x = x.unsqueeze(axis)
  1861. return x
  1862. @register_decomposition(aten.batch_norm_backward.default)
  1863. def batch_norm_backward(
  1864. grad_out: Tensor,
  1865. input: Tensor,
  1866. weight: Optional[Tensor],
  1867. running_mean: Optional[Tensor],
  1868. running_var: Optional[Tensor],
  1869. save_mean: Optional[Tensor],
  1870. save_invstd: Optional[Tensor],
  1871. train: bool,
  1872. eps: float,
  1873. output_mask: List[bool],
  1874. reserve: Tensor,
  1875. ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  1876. return native_batch_norm_backward(
  1877. grad_out,
  1878. input,
  1879. weight,
  1880. running_mean,
  1881. running_var,
  1882. save_mean,
  1883. save_invstd,
  1884. train,
  1885. eps,
  1886. output_mask,
  1887. )
  1888. @register_decomposition(aten.native_batch_norm_backward.default)
  1889. def native_batch_norm_backward(
  1890. grad_out: Tensor,
  1891. input: Tensor,
  1892. weight: Optional[Tensor],
  1893. running_mean: Optional[Tensor],
  1894. running_var: Optional[Tensor],
  1895. save_mean: Optional[Tensor],
  1896. save_invstd: Optional[Tensor],
  1897. train: bool,
  1898. eps: float,
  1899. output_mask: List[bool],
  1900. ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  1901. input_dtype = input.dtype
  1902. if weight is not None:
  1903. weight_dtype = weight.dtype
  1904. else:
  1905. weight_dtype = input_dtype
  1906. computation_dtype = utils.get_computation_dtype(input.dtype)
  1907. (
  1908. grad_out_cast,
  1909. input_cast,
  1910. weight_cast,
  1911. running_mean_cast,
  1912. running_var_cast,
  1913. save_mean_cast,
  1914. save_invstd_cast,
  1915. ) = (
  1916. x.to(computation_dtype) if x is not None else x
  1917. for x in (
  1918. grad_out,
  1919. input,
  1920. weight,
  1921. running_mean,
  1922. running_var,
  1923. save_mean,
  1924. save_invstd,
  1925. )
  1926. )
  1927. input_shape = input.shape
  1928. input_rank = input.dim()
  1929. assert input_rank >= 2, "rank of the input must be at least 2"
  1930. axis = 1
  1931. num_features = prod(list(input_shape)) / input_shape[axis]
  1932. mean = save_mean_cast
  1933. invstd = save_invstd_cast
  1934. if train:
  1935. assert save_mean_cast is not None and save_invstd_cast is not None
  1936. else:
  1937. assert running_mean_cast is not None and running_var_cast is not None
  1938. mean = running_mean_cast
  1939. invstd = torch.rsqrt(running_var_cast + eps)
  1940. broadcast_mask: List[int] = [1] * input_rank
  1941. broadcast_mask[axis] = input_shape[axis]
  1942. reduction_axes: List[int] = []
  1943. for i in range(input_rank):
  1944. if i != axis:
  1945. reduction_axes.append(i)
  1946. mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type]
  1947. norm = 1.0 / num_features
  1948. grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type]
  1949. dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator]
  1950. grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
  1951. proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator]
  1952. if weight_cast is None:
  1953. grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type]
  1954. else:
  1955. grad_scale = _broadcast_batch_norm_backward(
  1956. invstd * weight_cast, broadcast_mask
  1957. )
  1958. if train:
  1959. proj = (input_cast - mean) * proj_scale # type: ignore[operator]
  1960. grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
  1961. else:
  1962. grad_input = grad_out_cast * grad_scale
  1963. if output_mask[1]:
  1964. grad_weight = dot_p * invstd
  1965. else:
  1966. grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp
  1967. if output_mask[2]:
  1968. grad_bias = grad_output_sum
  1969. else:
  1970. grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp
  1971. return (
  1972. grad_input.to(input_dtype),
  1973. _maybe_cast(grad_weight, weight_dtype),
  1974. _maybe_cast(grad_bias, weight_dtype),
  1975. )
  1976. # out_wrapper currently does not allow optional outputs
  1977. @register_decomposition(aten.native_batch_norm_backward.out)
  1978. def native_batch_norm_backward_out(
  1979. grad_out: Tensor,
  1980. input: Tensor,
  1981. weight: Optional[Tensor],
  1982. running_mean: Optional[Tensor],
  1983. running_var: Optional[Tensor],
  1984. save_mean: Optional[Tensor],
  1985. save_invstd: Optional[Tensor],
  1986. train: bool,
  1987. eps: float,
  1988. output_mask: List[bool],
  1989. *,
  1990. out0: torch.Tensor,
  1991. out1: torch.Tensor,
  1992. out2: torch.Tensor,
  1993. ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  1994. result = native_batch_norm_backward(
  1995. grad_out,
  1996. input,
  1997. weight,
  1998. running_mean,
  1999. running_var,
  2000. save_mean,
  2001. save_invstd,
  2002. train,
  2003. eps,
  2004. output_mask,
  2005. )
  2006. grad_input = (out0, out1, out2)
  2007. for i, r in enumerate(result):
  2008. if r is not None:
  2009. _maybe_resize_out(grad_input[i], r.shape)
  2010. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  2011. return grad_input
  2012. @register_decomposition(aten.miopen_batch_norm_backward)
  2013. @out_wrapper("out0", "out1", "out2")
  2014. def miopen_batch_norm_backward(
  2015. input: Tensor,
  2016. grad_output: Tensor,
  2017. weight: Tensor,
  2018. running_mean: Optional[Tensor],
  2019. running_var: Optional[Tensor],
  2020. save_mean: Optional[Tensor],
  2021. save_var: Optional[Tensor],
  2022. epsilon: float,
  2023. ):
  2024. return aten.native_batch_norm_backward(
  2025. grad_output,
  2026. input,
  2027. weight,
  2028. running_mean,
  2029. running_var,
  2030. save_mean,
  2031. save_var,
  2032. True,
  2033. epsilon,
  2034. [True, True, True],
  2035. )
  2036. @register_decomposition(aten.cudnn_batch_norm_backward)
  2037. @out_wrapper("out0", "out1", "out2")
  2038. def cudnn_batch_norm_backward(
  2039. input: Tensor,
  2040. grad_output: Tensor,
  2041. weight: Tensor,
  2042. running_mean: Optional[Tensor],
  2043. running_var: Optional[Tensor],
  2044. save_mean: Optional[Tensor],
  2045. save_var: Optional[Tensor],
  2046. epsilon: float,
  2047. reserveSpace: Tensor,
  2048. ):
  2049. return aten.native_batch_norm_backward(
  2050. grad_output,
  2051. input,
  2052. weight,
  2053. running_mean,
  2054. running_var,
  2055. save_mean,
  2056. save_var,
  2057. True,
  2058. epsilon,
  2059. [True, True, True],
  2060. )
  2061. @register_decomposition(aten._adaptive_avg_pool2d)
  2062. @out_wrapper()
  2063. @pw_cast_for_opmath
  2064. def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
  2065. # Preconditions
  2066. device = input.device
  2067. shape = input.shape
  2068. ndim = len(shape)
  2069. torch._check(
  2070. ndim in (3, 4),
  2071. lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
  2072. )
  2073. for d in input.shape[-2:]:
  2074. torch._check(
  2075. d != 0,
  2076. lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
  2077. f"non-batch dimensions, but input has shape {tuple(shape)}.",
  2078. )
  2079. # Optimisation (we should also do this in the kernel implementation)
  2080. if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
  2081. stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
  2082. kernel = tuple(
  2083. i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
  2084. )
  2085. return torch.nn.functional.avg_pool2d(input, kernel, stride)
  2086. def start_index(a, b, c):
  2087. return torch.div(a * c, b, rounding_mode="trunc")
  2088. def end_index(a, b, c):
  2089. return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")
  2090. def compute_idx(in_size, out_size):
  2091. orange = torch.arange(out_size, device=device, dtype=torch.int64)
  2092. i0 = start_index(orange, out_size, in_size)
  2093. # Let length = end_index - start_index, i.e. the length of the pooling kernels
  2094. # length.max() can be computed analytically as follows:
  2095. maxlength = in_size // out_size + 1
  2096. in_size_mod = in_size % out_size
  2097. # adaptive = True iff there are kernels with different lengths
  2098. adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
  2099. if adaptive:
  2100. maxlength += 1
  2101. elif in_size_mod == 0:
  2102. maxlength -= 1
  2103. range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
  2104. idx = i0.unsqueeze(-1) + range_max
  2105. if adaptive:
  2106. # Need to clamp to avoid accessing out-of-bounds memory
  2107. # TODO make minimum accept scalars
  2108. maxval = torch.scalar_tensor(
  2109. in_size - 1, dtype=idx.dtype, device=idx.device
  2110. )
  2111. idx = torch.minimum(idx, maxval)
  2112. # Compute the length
  2113. i1 = end_index(orange, out_size, in_size)
  2114. length = i1 - i0
  2115. else:
  2116. length = maxlength
  2117. return idx, length, range_max, adaptive
2118. # length is a constant int when all pooling windows have the same size; otherwise it's a per-window tensor
  2119. idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
  2120. idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])
  2121. vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
  2122. # Shortcut for the simpler case
  2123. if not adaptive_h and not adaptive_w:
  2124. return torch.mean(vals, dim=(-3, -1))
  2125. def maybe_mask(vals, length, range_max, adaptive, dim):
  2126. if isinstance(length, IntLike):
  2127. return vals, length
  2128. else:
  2129. # zero-out the things we didn't really want to select
  2130. assert dim < 0
  2131. # hack
  2132. mask = range_max >= length.unsqueeze(-1)
  2133. if dim == -2:
  2134. mask = _unsqueeze_to_dim(mask, 4)
  2135. vals = torch.masked_fill(vals, mask, 0.0)
  2136. # Compute the length of each window
  2137. length = _unsqueeze_to_dim(length, -dim)
  2138. return vals, length
  2139. vals, length_h = maybe_mask(
  2140. vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
  2141. )
  2142. vals, length_w = maybe_mask(
  2143. vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
  2144. )
  2145. # We unroll the sum as we assume that the kernels are going to be small
  2146. ret = None
  2147. for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
  2148. if ret is None:
  2149. ret = vals[..., i, :, j]
  2150. else:
  2151. ret = ret + vals[..., i, :, j]
  2152. return ret / (length_h * length_w)
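# Worked example for start_index/end_index above: with in_size == 5 and out_size == 3 the
# pooling windows are [0, 2), [1, 4), [3, 5), i.e. lengths 2, 3, 2; this is the "adaptive"
# case that maybe_mask handles by zeroing the over-read positions.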
  2153. @register_decomposition(aten.index_add_)
  2154. def index_add_(
  2155. x: TensorLike,
  2156. dim: int,
  2157. index: TensorLike,
  2158. tensor: TensorLike,
  2159. *,
  2160. alpha: NumberType = 1,
  2161. ):
  2162. return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)
  2163. @register_decomposition(aten.index_add)
  2164. @out_wrapper()
  2165. def index_add(
  2166. x: TensorLike,
  2167. dim: int,
  2168. index: TensorLike,
  2169. tensor: TensorLike,
  2170. *,
  2171. alpha: NumberType = 1,
  2172. ):
  2173. return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
  2174. def _index_add(
  2175. x: TensorLike,
  2176. dim: int,
  2177. index: TensorLike,
  2178. tensor: TensorLike,
  2179. *,
  2180. inplace: bool,
  2181. alpha: NumberType = 1,
  2182. ):
  2183. dim = utils.canonicalize_dims(x.ndim, dim)
  2184. torch._check(
  2185. index.ndim <= 1,
  2186. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  2187. )
  2188. index_size = index.size(0) if index.ndim == 1 else 1
  2189. tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1
  2190. torch._check(
  2191. tensor_size == index_size,
  2192. lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}",
  2193. )
  2194. if alpha != 1:
  2195. python_type = utils.dtype_to_type(x.dtype)
  2196. torch._check(
  2197. python_type == bool
  2198. or utils.is_weakly_lesser_type(type(alpha), python_type),
  2199. lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
  2200. )
  2201. tensor = tensor * alpha
  2202. # Treat scalars as elements of \R^1
  2203. zero_dim = x.ndim == 0
  2204. x1 = x.unsqueeze(0) if zero_dim else x
  2205. idx = (None,) * dim + (index,)
  2206. index_put = aten.index_put_ if inplace else aten.index_put
  2207. out = index_put(x1, idx, tensor, accumulate=True)
  2208. if inplace:
  2209. return x
  2210. else:
  2211. return out.squeeze(0) if zero_dim else out.contiguous()
  2212. @register_decomposition(aten.pad_sequence.default)
  2213. @aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2214. def pad_sequence(sequences, batch_first=False, padding_value=0.0):
  2215. torch._check(len(sequences) > 0, lambda: "received an empty list of sequences")
  2216. sequences_size = len(sequences)
  2217. max_size = sequences[0].size()
  2218. trailing_dims = max_size[1:]
  2219. max_len = max(x.size(0) for x in sequences)
  2220. if batch_first:
  2221. out_dims = (sequences_size, max_len)
  2222. else:
  2223. out_dims = (max_len, sequences_size)
  2224. out_dims = out_dims + trailing_dims
  2225. out = sequences[0].new_full(out_dims, padding_value)
  2226. dim_paddings = (0, 0) * len(trailing_dims)
  2227. for i in range(sequences_size):
  2228. currseq = sequences[i]
  2229. row = aten.constant_pad_nd(
  2230. currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value
  2231. )
  2232. if batch_first:
  2233. out = aten.select_scatter(out, row, dim=0, index=i)
  2234. else:
  2235. out = aten.select_scatter(out, row, dim=1, index=i)
  2236. return out
  2237. @register_decomposition(aten.index_copy_)
  2238. def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  2239. return _index_copy(x, dim, index, tensor, inplace=True)
  2240. @register_decomposition(aten.index_copy)
  2241. @out_wrapper()
  2242. def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  2243. return _index_copy(x, dim, index, tensor, inplace=False)
  2244. def _index_copy(
  2245. x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
  2246. ):
  2247. dim = utils.canonicalize_dims(x.ndim, dim)
  2248. torch._check(
  2249. index.ndim <= 1,
  2250. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  2251. )
  2252. # Treat scalars as elements of \R^1
  2253. zero_dim = x.ndim == 0
  2254. x1 = x.unsqueeze(0) if zero_dim else x
  2255. index = index.unsqueeze(0) if index.ndim == 0 else index
  2256. idx = (None,) * dim + (index,)
  2257. index_put = aten.index_put_ if inplace else aten.index_put
  2258. out = index_put(x1, idx, tensor)
  2259. if inplace:
  2260. return x
  2261. else:
  2262. return out.squeeze(0) if zero_dim else out.contiguous()
  2263. # nb: Should use acc_t, not op_math
  2264. @register_decomposition(aten.log_sigmoid_forward)
  2265. @out_wrapper("output", "buffer")
  2266. @pw_cast_for_opmath
  2267. def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
  2268. min = torch.minimum(self.new_zeros(()), self)
  2269. z = torch.exp(-torch.abs(self))
  2270. if self.is_cuda:
  2271. buffer = self.new_zeros((0,))
  2272. else:
  2273. buffer = z
  2274. return min - torch.log1p(z), buffer
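# NB: the identity used above is log(sigmoid(x)) == min(0, x) - log1p(exp(-|x|)), which
# avoids overflowing exp() for large |x|; the buffer (exp(-|x|)) is only materialized on CPU,
# while the CUDA path returns an empty buffer, as written above.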
  2275. @register_decomposition(aten.uniform)
  2276. @out_wrapper()
  2277. def uniform(
  2278. x: Tensor,
  2279. low: Union[bool, int, float] = 0.0,
  2280. high: Union[bool, int, float] = 1.0,
  2281. generator: Optional[torch.Generator] = None,
  2282. ):
  2283. return prims._uniform_helper(
  2284. x.shape,
  2285. low=sym_float(low),
  2286. high=sym_float(high),
  2287. dtype=x.dtype,
  2288. device=x.device,
  2289. generator=generator,
  2290. )
  2291. @register_decomposition(aten.uniform_)
  2292. def uniform_(self, low=0, high=1, generator=None):
  2293. return self.copy_(uniform(self, low, high, generator))
  2294. # aten/src/ATen/native/UpSample.cpp compute_output_size
  2295. def upsample_compute_output_size(input_size, output_size, scale_factors):
  2296. spatial_dimensions = len(input_size) - 2
  2297. if output_size is not None:
  2298. torch._check(
  2299. scale_factors is None,
  2300. lambda: "Must specify exactly one of output_size and scale_factors",
  2301. )
  2302. torch._check(len(output_size) == spatial_dimensions, lambda: "")
  2303. return output_size
  2304. if scale_factors is not None:
  2305. # NB: this isn't necessary lol
  2306. torch._check(
  2307. output_size is None,
  2308. lambda: "Must specify exactly one of output_size and scale_factors",
  2309. )
  2310. torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
  2311. output_size = []
  2312. for i, s in enumerate(scale_factors):
  2313. if int(s) == s:
  2314. output_size.append(input_size[i + 2] * int(s))
  2315. else:
  2316. output_size.append(sym_int(input_size[i + 2] * s))
  2317. return output_size
  2318. torch._check(
  2319. False, lambda: "Must specify exactly one of output_size and scale_factors"
  2320. )
  2321. def get_scale_value(scales, idx):
  2322. if scales is None:
  2323. return None
  2324. return scales[idx]
  2325. @register_decomposition(aten.upsample_nearest1d.vec)
  2326. @register_decomposition(aten.upsample_nearest2d.vec)
  2327. @register_decomposition(aten.upsample_nearest3d.vec)
  2328. @aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2329. @aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
  2330. @aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2331. @aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
  2332. @aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2333. @aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
  2334. def _upsample_nearest_vec(
  2335. input: Tensor,
  2336. output_size: Optional[List[int]],
  2337. scale_factors: Optional[List[float]],
  2338. ) -> Tensor:
  2339. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  2340. scales = (
  2341. scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
  2342. )
  2343. return _upsample_nearest(input, osize, scales)
  2344. @register_decomposition(aten._upsample_nearest_exact1d.vec)
  2345. @register_decomposition(aten._upsample_nearest_exact2d.vec)
  2346. @register_decomposition(aten._upsample_nearest_exact3d.vec)
  2347. @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2348. @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
  2349. @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2350. @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
  2351. @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2352. @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
  2353. def _upsample_nearest_exact_vec(
  2354. input: Tensor,
  2355. output_size: Optional[List[int]],
  2356. scale_factors: Optional[List[float]],
  2357. ) -> Tensor:
  2358. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  2359. scales = (
  2360. scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
  2361. )
  2362. return _upsample_nearest(input, osize, scales, exact=True)
  2363. def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
  2364. # For each dim in output_size, compute the set of input indices used
  2365. # to produce the upsampled output.
  2366. indices = []
  2367. num_spatial_dims = len(output_size)
  2368. offset = 0.5 if exact else 0.0
  2369. for d in range(num_spatial_dims):
  2370. # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
  2371. #
2372. # Indices are computed as follows:
  2373. # scale = isize / osize
  2374. # Case: exact=False
  2375. # input_index = floor(output_index * scale)
  2376. # Same as OpenCV INTER_NEAREST
  2377. #
2378. # Case: exact=True
  2379. # index_f32 = (output_index + 0.5) * scale - 0.5
  2380. # input_index = round(index_f32)
  2381. # Same as Pillow and Scikit-Image/Scipy ndi.zoom
  2382. osize = output_size[d]
  2383. isize = input.shape[-num_spatial_dims + d]
  2384. scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
  2385. output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
  2386. input_indices = ((output_indices + offset) * scale).to(torch.int64)
  2387. for _ in range(num_spatial_dims - 1 - d):
  2388. input_indices = input_indices.unsqueeze(-1)
  2389. indices.append(input_indices)
  2390. return indices
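# Illustrative sketch (hypothetical `_example_*` helper, not a registered decomposition):
# the index math above for a 1d resize from isize=4 to osize=3 (scale = 4/3), a case
# where the plain and "exact" variants actually differ.
#   exact=False: floor(output_index * 4/3)         -> [0, 1, 2]
#   exact=True : floor((output_index + 0.5) * 4/3) -> [0, 2, 3]
def _example_upsample_nearest_indices():
    x = torch.arange(4.0).view(1, 1, 4)
    (idx_nearest,) = _compute_upsample_nearest_indices(x, [3], [None])
    (idx_exact,) = _compute_upsample_nearest_indices(x, [3], [None], exact=True)
    assert idx_nearest.tolist() == [0, 1, 2] and idx_exact.tolist() == [0, 2, 3]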
  2391. @register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
  2392. @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2393. @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
  2394. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2395. def upsample_nearest1d(
  2396. input: Tensor,
  2397. output_size: List[int],
  2398. scales: Optional[float] = None,
  2399. ) -> Tensor:
  2400. return _upsample_nearest(input, output_size, [scales])
  2401. @register_decomposition(
  2402. [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
  2403. )
  2404. @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2405. @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
  2406. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2407. def upsample_nearest_exact1d(
  2408. input: Tensor,
  2409. output_size: List[int],
  2410. scales: Optional[float] = None,
  2411. ) -> Tensor:
  2412. return _upsample_nearest(input, output_size, [scales], exact=True)
  2413. @register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
  2414. @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2415. @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
  2416. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2417. def upsample_nearest2d(
  2418. input: Tensor,
  2419. output_size: List[int],
  2420. scales_h: Optional[float] = None,
  2421. scales_w: Optional[float] = None,
  2422. ) -> Tensor:
  2423. return _upsample_nearest(input, output_size, [scales_h, scales_w])
  2424. @register_decomposition(
  2425. [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
  2426. )
  2427. @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2428. @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
  2429. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2430. def _upsample_nearest_exact2d(
  2431. input: Tensor,
  2432. output_size: List[int],
  2433. scales_h: Optional[float] = None,
  2434. scales_w: Optional[float] = None,
  2435. ) -> Tensor:
  2436. return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
  2437. @register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
  2438. @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2439. @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
  2440. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2441. def upsample_nearest3d(
  2442. input: Tensor,
  2443. output_size: List[int],
  2444. scales_d: Optional[float] = None,
  2445. scales_h: Optional[float] = None,
  2446. scales_w: Optional[float] = None,
  2447. ) -> Tensor:
  2448. return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
  2449. @register_decomposition(
  2450. [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
  2451. )
  2452. @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2453. @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
  2454. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2455. def _upsample_nearest_exact3d(
  2456. input: Tensor,
  2457. output_size: List[int],
  2458. scales_d: Optional[float] = None,
  2459. scales_h: Optional[float] = None,
  2460. scales_w: Optional[float] = None,
  2461. ) -> Tensor:
  2462. return _upsample_nearest(
  2463. input, output_size, [scales_d, scales_h, scales_w], exact=True
  2464. )
  2465. @pw_cast_for_opmath
  2466. def _upsample_nearest(
  2467. input: Tensor,
  2468. output_size: List[int],
  2469. scales: List[Optional[float]],
  2470. exact: bool = False,
  2471. ) -> Tensor:
  2472. spatial_indices = _compute_upsample_nearest_indices(
  2473. input, output_size, scales, exact=exact
  2474. )
  2475. indices = [None, None] + spatial_indices
  2476. result = aten._unsafe_index(input, indices)
  2477. if result.ndim == 4:
  2478. # convert output to correct memory format, if necessary
  2479. memory_format = utils.suggest_memory_format(input)
  2480. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  2481. n_channels = input.shape[1]
  2482. if input.device.type == "cuda" and n_channels < 4:
  2483. memory_format = torch.contiguous_format
  2484. result = result.contiguous(memory_format=memory_format)
  2485. return result
  2486. def gather_params(params, has_biases, has_projections):
  2487. if has_biases and has_projections:
  2488. group_size = 5
  2489. elif has_biases:
  2490. group_size = 4
  2491. elif has_projections:
  2492. group_size = 3
  2493. else:
  2494. group_size = 2
  2495. assert len(params) % group_size == 0, len(params)
  2496. return [
  2497. tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
  2498. ]
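# Illustrative sketch (hypothetical `_example_*` helper): for a 2-layer unidirectional
# RNN with biases and no projections, group_size is 4, so the flat list
# [w_ih0, w_hh0, b_ih0, b_hh0, w_ih1, w_hh1, b_ih1, b_hh1] becomes two 4-tuples.
def _example_gather_params():
    flat = [torch.empty(0) for _ in range(8)]  # 2 layers x (w_ih, w_hh, b_ih, b_hh)
    groups = gather_params(flat, has_biases=True, has_projections=False)
    assert len(groups) == 2 and all(len(g) == 4 for g in groups)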
  2499. def params_hiddens(params, hiddens, i, bidirectional):
  2500. if bidirectional:
  2501. cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
  2502. bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
  2503. else:
  2504. cur_params, cur_hidden = params[i], hiddens[i]
  2505. bidir_params, bidir_hidden = None, None
  2506. return cur_params, cur_hidden, bidir_params, bidir_hidden
  2507. def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
  2508. assert last_batch_size > batch_size
  2509. hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
  2510. return cur_hidden.narrow(0, 0, batch_size)
  2511. def update_hidden_for_packed_reverse(
  2512. cur_hidden, last_batch_size, batch_size, inp_hidden
  2513. ):
  2514. if last_batch_size == batch_size:
  2515. return cur_hidden
  2516. assert last_batch_size < batch_size
  2517. return torch.concat(
  2518. (
  2519. cur_hidden,
  2520. inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
  2521. )
  2522. )
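# Illustrative sketch (hypothetical `_example_*` helper): in the packed forward pass the
# working hidden state shrinks as shorter sequences finish; the rows that drop out are
# stashed in `hiddens` so they can be reassembled into the final hidden state later.
def _example_update_hidden_for_packed():
    hiddens: List[torch.Tensor] = []
    cur = torch.zeros(3, 5)  # hidden state while 3 sequences are still active
    cur = update_hidden_for_packed(cur, last_batch_size=3, batch_size=2, hiddens=hiddens)
    assert cur.shape == (2, 5) and hiddens[0].shape == (1, 5)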
  2523. def one_layer_rnn_data(
  2524. inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
  2525. ):
  2526. ih_weight = params[0]
  2527. hh_weight = params[1]
  2528. ih_bias = params[2] if has_biases else None
  2529. hh_bias = params[3] if has_biases else None
  2530. step_output = []
  2531. hiddens: List[torch.Tensor] = []
  2532. last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
  2533. cur_hidden = hidden.narrow(0, 0, last_batch_size)
  2534. split_inp = torch.split(inp, list(batch_sizes))
  2535. if reverse:
  2536. split_inp = split_inp[::-1]
  2537. for inp in split_inp:
  2538. i = inp.shape[0]
  2539. if last_batch_size == i:
  2540. pass # don't update cur_hidden
2541. # otherwise the step batch size changed: it grows in reverse mode and shrinks in forward mode, since batch sizes are sorted largest -> smallest
  2542. elif reverse:
  2543. cur_hidden = update_hidden_for_packed_reverse(
  2544. cur_hidden, last_batch_size, i, hidden
  2545. )
  2546. else:
  2547. cur_hidden = update_hidden_for_packed(
  2548. cur_hidden, last_batch_size, i, hiddens
  2549. )
  2550. cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
  2551. last_batch_size = i
  2552. step_output.append(cur_hidden)
  2553. if reverse:
  2554. step_output.reverse()
  2555. else:
  2556. hiddens.append(cur_hidden)
  2557. hiddens.reverse()
  2558. out = torch.cat(step_output, 0)
  2559. hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
  2560. return out, hidden_out
  2561. def rnn_cell(nonlinearity):
  2562. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2563. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  2564. return inner
  2565. def rnn_cell_data(nonlinearity):
  2566. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2567. i = F.linear(i, ih_weight, ih_bias)
  2568. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  2569. return inner
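# Note (illustrative commentary, not from the original sources): the two cell factories
# differ only in where the input projection happens. one_layer_rnn precomputes
# F.linear(inp, ih_weight, ih_bias) for the whole sequence, so the cell returned by
# rnn_cell ignores its ih_* arguments; the packed-data path feeds the raw step input,
# so rnn_cell_data applies the projection inside each step. Either way the update is
#     h_t = nonlinearity(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)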
  2570. def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
  2571. ih_weight = params[0]
  2572. hh_weight = params[1]
  2573. ih_bias = params[2] if has_biases else None
  2574. hh_bias = params[3] if has_biases else None
  2575. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  2576. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  2577. cur_hidden = hidden.unsqueeze(0)
  2578. step_output = []
  2579. for i in precomputed_input:
  2580. cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
  2581. step_output.append(cur_hidden)
  2582. if reverse:
  2583. step_output.reverse()
  2584. out = torch.cat(step_output, 0)
  2585. return out, cur_hidden.squeeze(0)
  2586. def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
  2587. w0 = params[0]
  2588. w1 = params[1]
  2589. if has_biases:
  2590. w2 = params[2]
  2591. w3 = params[3]
  2592. else:
  2593. w2 = torch.zeros(w0.size())
  2594. w3 = torch.zeros(w1.size())
  2595. hx = hidden[0].unsqueeze(0)
  2596. cx = hidden[1].unsqueeze(0)
  2597. batch_sizes: List[int] = []
  2598. mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
  2599. hidden_size = hx.size(2)
  2600. num_layers = 1
  2601. # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here
  2602. bidirectional = False
  2603. batch_first = False
  2604. train = False
  2605. # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here.
  2606. # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous();
  2607. inp = inp.contiguous()
  2608. hx = hx.contiguous()
  2609. cx = cx.contiguous()
  2610. outputs = torch.ops.aten.mkldnn_rnn_layer.default(
  2611. inp,
  2612. w0,
  2613. w1,
  2614. w2,
  2615. w3,
  2616. hx,
  2617. cx,
  2618. reverse,
  2619. batch_sizes,
  2620. mode,
  2621. hidden_size,
  2622. num_layers,
  2623. has_biases,
  2624. bidirectional,
  2625. batch_first,
  2626. train,
  2627. )
  2628. y, hy, cy = outputs[0], outputs[1], outputs[2]
  2629. return y, (hy.squeeze(0), cy.squeeze(0))
  2630. def _rnn_helper(
  2631. input,
  2632. hidden,
  2633. params,
  2634. has_biases,
  2635. num_layers,
  2636. dropout,
  2637. train,
  2638. bidirectional,
  2639. batch_first,
  2640. layer_fn,
  2641. ):
  2642. input = input.transpose(0, 1) if batch_first else input
  2643. final_hiddens = []
  2644. for i in range(num_layers):
  2645. cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
  2646. params, hidden, i, bidirectional
  2647. )
2648. dropout = dropout if (train and i < num_layers - 1) else 0.0
  2649. fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
  2650. final_hiddens.append(fwd_hidden)
  2651. if bidirectional:
  2652. bwd_inp, bwd_hidden = layer_fn(
  2653. input, bidir_hidden, bidir_params, has_biases, reverse=True
  2654. )
  2655. final_hiddens.append(bwd_hidden)
  2656. if bidirectional:
  2657. input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined]
  2658. else:
  2659. input = fwd_inp
  2660. if dropout != 0 and train and i < num_layers - 1:
  2661. input = torch.dropout(input, dropout, train=True)
  2662. input = input.transpose(0, 1) if batch_first else input
  2663. return input, final_hiddens
  2664. @register_decomposition(aten.rnn_tanh.input)
  2665. @aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  2666. @aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
  2667. def rnn_tanh_input(
  2668. input,
  2669. hx,
  2670. params,
  2671. has_biases,
  2672. num_layers,
  2673. dropout,
  2674. train,
  2675. bidirectional,
  2676. batch_first,
  2677. ):
  2678. hidden = hx.unbind(0)
  2679. params = gather_params(params, has_biases, False)
  2680. out, final_hiddens = _rnn_helper(
  2681. input,
  2682. hidden,
  2683. params,
  2684. has_biases,
  2685. num_layers,
  2686. dropout,
  2687. train,
  2688. bidirectional,
  2689. batch_first,
  2690. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
  2691. )
  2692. return out, torch.stack(final_hiddens, 0)
  2693. @register_decomposition(aten.rnn_relu.input)
  2694. @aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  2695. @aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
  2696. def rnn_relu_input(
  2697. input,
  2698. hx,
  2699. params,
  2700. has_biases,
  2701. num_layers,
  2702. dropout,
  2703. train,
  2704. bidirectional,
  2705. batch_first,
  2706. ):
  2707. hidden = hx.unbind(0)
  2708. params = gather_params(params, has_biases, False)
  2709. out, final_hiddens = _rnn_helper(
  2710. input,
  2711. hidden,
  2712. params,
  2713. has_biases,
  2714. num_layers,
  2715. dropout,
  2716. train,
  2717. bidirectional,
  2718. batch_first,
  2719. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
  2720. )
  2721. return out, torch.stack(final_hiddens, 0)
  2722. @register_decomposition(aten.rnn_relu.data)
  2723. @aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2724. @aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
  2725. def rnn_relu_data(
  2726. data,
  2727. batch_sizes,
  2728. hx,
  2729. params,
  2730. has_biases,
  2731. num_layers,
  2732. dropout,
  2733. train,
  2734. bidirectional,
  2735. ):
  2736. hidden = hx.unbind(0)
  2737. params = gather_params(params, has_biases, False)
  2738. out, final_hiddens = _rnn_helper(
  2739. data,
  2740. hidden,
  2741. params,
  2742. has_biases,
  2743. num_layers,
  2744. dropout,
  2745. train,
  2746. bidirectional,
  2747. False,
  2748. partial(
  2749. one_layer_rnn_data,
  2750. batch_sizes=batch_sizes,
  2751. hidden_fn=rnn_cell_data(torch.relu),
  2752. ),
  2753. )
  2754. return out, torch.stack(final_hiddens, 0)
  2755. @register_decomposition(aten.rnn_tanh.data)
  2756. @aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2757. @aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
  2758. def rnn_tanh_data(
  2759. data,
  2760. batch_sizes,
  2761. hx,
  2762. params,
  2763. has_biases,
  2764. num_layers,
  2765. dropout,
  2766. train,
  2767. bidirectional,
  2768. ):
  2769. hidden = hx.unbind(0)
  2770. params = gather_params(params, has_biases, False)
  2771. out, final_hiddens = _rnn_helper(
  2772. data,
  2773. hidden,
  2774. params,
  2775. has_biases,
  2776. num_layers,
  2777. dropout,
  2778. train,
  2779. bidirectional,
  2780. False,
  2781. partial(
  2782. one_layer_rnn_data,
  2783. batch_sizes=batch_sizes,
  2784. hidden_fn=rnn_cell_data(torch.tanh),
  2785. ),
  2786. )
  2787. return out, torch.stack(final_hiddens, 0)
  2788. def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
  2789. gates = F.linear(hx, hh_weight, hh_bias) + inp
  2790. chunked_gates = gates.chunk(4, chunk_dim)
  2791. in_gate = chunked_gates[0].sigmoid()
  2792. forget_gate = chunked_gates[1].sigmoid()
  2793. cell_gate = chunked_gates[2].tanh()
  2794. out_gate = chunked_gates[3].sigmoid()
  2795. cy = forget_gate * cx + (in_gate * cell_gate)
  2796. hy = out_gate * cy.tanh()
  2797. hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
  2798. return hy, cy
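# Note (illustrative commentary): lstm_cell is the standard LSTM update, with `inp` being
# the precomputed input projection W_ih x_t + b_ih and the gates chunked from
# inp + W_hh h_{t-1} + b_hh:
#     i_t = sigmoid(g_0), f_t = sigmoid(g_1), g_t = tanh(g_2), o_t = sigmoid(g_3)
#     c_t = f_t * c_{t-1} + i_t * g_t
#     h_t = o_t * tanh(c_t), optionally projected by hr_weight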
  2799. def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
  2800. ih_weight = params[0]
  2801. hh_weight = params[1]
  2802. ih_bias = params[2] if has_biases else None
  2803. hh_bias = params[3] if has_biases else None
  2804. hr_weight = (
  2805. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  2806. )
  2807. hx = hidden[0].unsqueeze(0)
  2808. cx = hidden[1].unsqueeze(0)
  2809. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  2810. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  2811. step_output = []
  2812. for inp in precomputed_input:
  2813. hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
  2814. step_output.append(hx)
  2815. if reverse:
  2816. step_output.reverse()
  2817. out = torch.cat(step_output, 0)
2818. return out, (hx.squeeze(0), cx.squeeze(0))
  2819. def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
  2820. ih_weight = params[0]
  2821. hh_weight = params[1]
  2822. ih_bias = params[2] if has_biases else None
  2823. hh_bias = params[3] if has_biases else None
  2824. hr_weight = (
  2825. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  2826. )
  2827. step_output = []
  2828. hiddens = []
  2829. last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
  2830. split_inp = torch.split(inp, list(batch_sizes))
  2831. if reverse:
  2832. split_inp = split_inp[::-1]
  2833. orig_hx = hidden[0]
  2834. orig_cx = hidden[1]
  2835. hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow(
  2836. 0, 0, last_batch_size
  2837. )
  2838. for inp in split_inp:
  2839. i = inp.shape[0]
  2840. inp = F.linear(inp, ih_weight, ih_bias)
  2841. # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
  2842. if i < last_batch_size:
  2843. hiddens.append(
  2844. (
  2845. hx.narrow(0, i, last_batch_size - i),
  2846. cx.narrow(0, i, last_batch_size - i),
  2847. )
  2848. )
  2849. hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
  2850. # this will only happen when reverse=True
  2851. if i > last_batch_size:
  2852. hx = torch.concat(
  2853. (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
  2854. )
  2855. cx = torch.concat(
  2856. (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
  2857. )
  2858. hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
  2859. last_batch_size = i
  2860. step_output.append(hx)
  2861. if reverse:
  2862. step_output.reverse()
  2863. hidden_out = (hx, cx)
  2864. else:
  2865. hiddens.append((hx, cx))
  2866. hiddens.reverse()
  2867. hidden0, hidden1 = zip(*hiddens)
  2868. hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
  2869. out = torch.cat(step_output, 0)
  2870. return out, hidden_out
  2871. def select_one_layer_lstm_function(input, hx, params):
  2872. r"""Check whether we could use decompose lstm with mkldnn_rnn_layer.
  2873. All the below conditions need to be met:
  2874. * ``torch._C._get_mkldnn_enabled()`` returns ``True``.
  2875. * All the input args are on CPU.
  2876. * The dtypes of args are either torch.float or torch.bfloat16.
2877. * Running in inference mode (the input does not require grad).
2878. * The LSTM has no projections (``has_projections`` is ``False``).
  2879. Args:
  2880. * input: the input sequence to LSTM
  2881. * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM
  2882. * params: the weight and bias tensors of LSTM
  2883. """
  2884. def use_mkldnn(input, hx, params):
  2885. if not torch._C._get_mkldnn_enabled():
  2886. return False
  2887. tensors = [input] + list(hx) + list(chain.from_iterable(params))
  2888. devices = {t.device for t in tensors}
  2889. if len(devices) != 1:
  2890. return False
  2891. device = devices.pop()
  2892. if device != torch.device("cpu"):
  2893. return False
  2894. # With autocast, possible to have mixed dtype here
  2895. dtypes = {t.dtype for t in tensors}
  2896. for dtype in dtypes:
  2897. if dtype not in [torch.float, torch.bfloat16]:
  2898. return False
  2899. if input.requires_grad:
  2900. return False
  2901. has_projections = hx[0].size(2) != hx[1].size(2)
  2902. if has_projections:
  2903. return False
  2904. return True
  2905. # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm
  2906. # will expand over the seq_len dim
  2907. if use_mkldnn(input, hx, params):
  2908. return mkldnn_one_layer_lstm
  2909. else:
  2910. return one_layer_lstm
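# Illustrative sketch (hypothetical `_example_*` helper): inputs that require grad (or a
# build without mkldnn) fall back to the pure decomposition path one_layer_lstm.
def _example_select_one_layer_lstm_function():
    inp = torch.randn(4, 2, 3, requires_grad=True)  # requires_grad -> no mkldnn
    hx = (torch.zeros(1, 2, 5), torch.zeros(1, 2, 5))
    params = gather_params(
        [torch.randn(20, 3), torch.randn(20, 5), torch.randn(20), torch.randn(20)],
        True,
        False,
    )
    assert select_one_layer_lstm_function(inp, hx, params) is one_layer_lstm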
  2911. @register_decomposition(aten.lstm.input)
  2912. @aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  2913. @aten.lstm.input.py_impl(DispatchKey.Autograd)
  2914. def lstm_impl(
  2915. input,
  2916. hx,
  2917. params,
  2918. has_biases,
  2919. num_layers,
  2920. dropout,
  2921. train,
  2922. bidirectional,
  2923. batch_first,
  2924. ):
  2925. assert len(hx) == 2, "lstm expects two hidden states"
  2926. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  2927. hidden = list(zip(hx[0], hx[1]))
  2928. layer_fn = select_one_layer_lstm_function(input, hx, params)
  2929. out, final_hiddens = _rnn_helper(
  2930. input,
  2931. hidden,
  2932. params,
  2933. has_biases,
  2934. num_layers,
  2935. dropout,
  2936. train,
  2937. bidirectional,
  2938. batch_first,
  2939. layer_fn,
  2940. )
  2941. final_hiddens = list(zip(*final_hiddens))
  2942. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  2943. @register_decomposition(aten.lstm.data)
  2944. @aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2945. @aten.lstm.data.py_impl(DispatchKey.Autograd)
  2946. def lstm_data_impl(
  2947. data,
  2948. batch_sizes,
  2949. hx,
  2950. params,
  2951. has_biases,
  2952. num_layers,
  2953. dropout,
  2954. train,
  2955. bidirectional,
  2956. ):
  2957. assert len(hx) == 2, "lstm expects two hidden states"
  2958. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  2959. hidden = list(zip(hx[0], hx[1]))
  2960. out, final_hiddens = _rnn_helper(
  2961. data,
  2962. hidden,
  2963. params,
  2964. has_biases,
  2965. num_layers,
  2966. dropout,
  2967. train,
  2968. bidirectional,
  2969. False,
  2970. partial(one_layer_lstm_data, batch_sizes=batch_sizes),
  2971. )
  2972. final_hiddens = list(zip(*final_hiddens))
  2973. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  2974. def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2975. chunked_igates = inp.chunk(3, 1)
  2976. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
  2977. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  2978. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  2979. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  2980. return (cur_hidden - new_gate) * input_gate + new_gate
  2981. def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2982. chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
  2983. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
  2984. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  2985. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  2986. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  2987. return (cur_hidden - new_gate) * input_gate + new_gate
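# Note (illustrative commentary): both GRU cells implement the standard update
# (biases folded into the F.linear calls):
#     r_t = sigmoid(W_ir x_t + W_hr h_{t-1})        # reset_gate
#     z_t = sigmoid(W_iz x_t + W_hz h_{t-1})        # input_gate (update gate)
#     n_t = tanh(W_in x_t + r_t * (W_hn h_{t-1}))   # new_gate
#     h_t = (h_{t-1} - n_t) * z_t + n_t == z_t * h_{t-1} + (1 - z_t) * n_t
# gru_cell expects the input projection to be precomputed for the whole sequence, while
# gru_cell_data projects each packed step on the fly.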
  2988. @register_decomposition(aten.gru.data)
  2989. @aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2990. @aten.gru.data.py_impl(DispatchKey.Autograd)
  2991. def gru_impl_data(
  2992. data,
  2993. batch_sizes,
  2994. hx,
  2995. params,
  2996. has_biases,
  2997. num_layers,
  2998. dropout,
  2999. train,
  3000. bidirectional,
  3001. ):
  3002. params = gather_params(params, has_biases, False)
  3003. out, final_hiddens = _rnn_helper(
  3004. data,
  3005. hx.unbind(0),
  3006. params,
  3007. has_biases,
  3008. num_layers,
  3009. dropout,
  3010. train,
  3011. bidirectional,
  3012. False,
  3013. partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
  3014. )
  3015. return out, torch.stack(final_hiddens, 0)
  3016. @register_decomposition(aten.gru.input)
  3017. @aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  3018. @aten.gru.input.py_impl(DispatchKey.Autograd)
  3019. def gru_impl(
  3020. input,
  3021. hx,
  3022. params,
  3023. has_biases,
  3024. num_layers,
  3025. dropout,
  3026. train,
  3027. bidirectional,
  3028. batch_first,
  3029. ):
  3030. params = gather_params(params, has_biases, False)
  3031. out, final_hiddens = _rnn_helper(
  3032. input,
  3033. hx.unbind(0),
  3034. params,
  3035. has_biases,
  3036. num_layers,
  3037. dropout,
  3038. train,
  3039. bidirectional,
  3040. batch_first,
  3041. partial(one_layer_rnn, hidden_fn=gru_cell),
  3042. )
  3043. return out, torch.stack(final_hiddens, 0)
  3044. @register_decomposition(aten._upsample_bilinear2d_aa.vec)
  3045. @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3046. @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
  3047. def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
  3048. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3049. scale_h = get_scale_value(scale_factors, 0)
  3050. scale_w = get_scale_value(scale_factors, 1)
  3051. return torch.ops.aten._upsample_bilinear2d_aa(
  3052. input, osize, align_corners, scale_h, scale_w
  3053. )
  3054. @register_decomposition(aten._upsample_bicubic2d_aa.vec)
  3055. @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3056. @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd)
  3057. def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors):
  3058. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3059. scale_h = get_scale_value(scale_factors, 0)
  3060. scale_w = get_scale_value(scale_factors, 1)
  3061. return torch.ops.aten._upsample_bicubic2d_aa(
  3062. input, osize, align_corners, scale_h, scale_w
  3063. )
  3064. @register_decomposition(aten.upsample_bilinear2d.vec)
  3065. @register_decomposition(aten.upsample_trilinear3d.vec)
  3066. @aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3067. @aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd)
  3068. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3069. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
  3070. @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3071. @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd)
  3072. def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
  3073. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3074. scales = scale_factors if scale_factors else [None] * len(osize)
  3075. return _upsample_linear(input, osize, align_corners, scales)
  3076. @register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out])
  3077. @out_wrapper()
  3078. def upsample_linear1d(
  3079. input: Tensor,
  3080. output_size: List[int],
  3081. align_corners: bool,
  3082. scales_w: Optional[float] = None,
  3083. ) -> Tensor:
  3084. return _upsample_linear(input, output_size, align_corners, [scales_w])
  3085. @register_decomposition(
  3086. [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
  3087. )
  3088. @aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
  3089. @out_wrapper()
  3090. def upsample_bilinear2d(
  3091. input: Tensor,
  3092. output_size: List[int],
  3093. align_corners: bool,
  3094. scales_h: Optional[float] = None,
  3095. scales_w: Optional[float] = None,
  3096. ) -> Tensor:
  3097. return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w])
  3098. @register_decomposition(
  3099. [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out]
  3100. )
  3101. @out_wrapper()
  3102. def upsample_trilinear3d(
  3103. input: Tensor,
  3104. output_size: List[int],
  3105. align_corners: bool,
  3106. scales_d: Optional[float] = None,
  3107. scales_h: Optional[float] = None,
  3108. scales_w: Optional[float] = None,
  3109. ) -> Tensor:
  3110. return _upsample_linear(
  3111. input, output_size, align_corners, [scales_d, scales_h, scales_w]
  3112. )
  3113. def _compute_scale(in_size, out_size, align_corners, scale=None):
  3114. if align_corners:
  3115. return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0
  3116. else:
  3117. return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size
  3118. def _compute_source_index(scale, dst_index, align_corners):
  3119. if align_corners:
  3120. return scale * dst_index
  3121. else:
  3122. return scale * (dst_index + 0.5) - 0.5
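# Note (illustrative worked example): resizing a length-4 axis to length 8,
#   align_corners=True : scale = (4 - 1) / (8 - 1) = 3/7,  src = scale * dst,
#                        so dst = 7 maps exactly onto the last input pixel (src = 3.0);
#   align_corners=False: scale = 4 / 8 = 0.5,  src = 0.5 * (dst + 0.5) - 0.5,
#                        so dst = 7 maps to src = 3.25 (clamped later by the callers).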
  3123. def _sum_tensors_uint8(
  3124. src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor
  3125. ) -> Tensor:
  3126. output = _sum_tensors(
  3127. s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights)
  3128. ) + (1 << (weights_precision - 1))
  3129. output = output >> weights_precision
  3130. return torch.clamp(output, 0, 255).to(torch.uint8)
  3131. def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
  3132. max_weight = torch.stack(weights).max()
  3133. max_weight_precision = 22
  3134. precisions = torch.arange(max_weight_precision, device=max_weight.device)
  3135. values = 0.5 + max_weight * (1 << (precisions + 1))
  3136. mask = values >= (1 << 15)
  3137. return max_weight_precision - mask.sum()
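# Note (illustrative commentary): for uint8 inputs the cubic weights are applied in fixed
# point. _compute_weight_precision picks a shift (at most 22) small enough that the
# scaled max weight still fits in int16; each weight is then stored as roughly
# round(w * 2**precision), products are accumulated in int32, and _sum_tensors_uint8
# adds 2**(precision - 1) before shifting right by `precision` so the sum is rounded,
# then clamps the result to [0, 255].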
  3138. @pw_cast_for_opmath
  3139. def _upsample_linear(
  3140. input: Tensor,
  3141. output_size: List[int],
  3142. align_corners: bool,
  3143. scales: List[Optional[float]],
  3144. ) -> Tensor:
  3145. # get dimensions of original image
  3146. n_batch, n_channels = input.shape[:2]
  3147. inp_sizes = input.shape[2:]
  3148. n_dims = len(inp_sizes)
  3149. _, dtype = utils.elementwise_dtypes(
  3150. input,
  3151. type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  3152. )
  3153. def get_values(inp_size, out_size, scales, nsqueeze):
  3154. # First Calculate scaling factor
  3155. scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
  3156. # We have to create arange with int64 dtype and use .to in order to avoid
3157. # additional kernel creation in inductor and the resulting perf slowdown
  3158. i = torch.arange(out_size, device=input.device).to(dtype=dtype)
  3159. x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
  3160. x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
  3161. x = x_f32.to(torch.int64)
  3162. xp1 = (x + 1).clamp(max=inp_size - 1)
  3163. return x_f32, x, xp1
  3164. values = [
  3165. get_values(inp_size, out_size, scales, n_dims - 1 - i)
  3166. for i, (inp_size, out_size, scales) in enumerate(
  3167. zip(inp_sizes, output_size, scales)
  3168. )
  3169. ]
  3170. xs_f32, xs, xp1s = list(zip(*values))
  3171. vs = []
  3172. for a in product(*[[0, 1]] * n_dims):
  3173. idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)]
  3174. v = aten._unsafe_index(input, idx)
  3175. v = _maybe_convert_to_dtype(v, dtype)
  3176. vs.append(v)
  3177. for i in reversed(range(n_dims)):
  3178. xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
  3179. vs = [
  3180. # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha
  3181. v1 + torch.mul(v2 - v1, xscale)
  3182. for v1, v2 in zip(vs[::2], vs[1::2])
  3183. ]
  3184. assert len(vs) == 1
  3185. result = vs[0]
  3186. # convert output to correct memory format, if necessary
  3187. memory_format = utils.suggest_memory_format(input)
  3188. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  3189. if input.device.type == "cuda" and n_channels < 16:
  3190. memory_format = torch.contiguous_format
  3191. assert isinstance(result, torch.Tensor)
  3192. result = result.contiguous(memory_format=memory_format)
  3193. if not input.is_floating_point():
  3194. result = result.round()
  3195. return result
  3196. # We should be applying decompositions after all transformations
  3197. @register_decomposition(aten.is_same_size.default)
  3198. def is_same_size(a: Tensor, b: Tensor) -> bool:
  3199. return a.shape == b.shape
  3200. @register_decomposition([aten._reshape_alias, aten._unsafe_view])
  3201. @out_wrapper()
  3202. def _reshape_alias(x, shape, *args):
  3203. return aten.view(x, shape)
  3204. @register_decomposition([aten._unsafe_index])
  3205. def _index(x, indices):
  3206. return aten.index(x, indices)
  3207. def _nll_loss_forward(
  3208. self: Tensor,
  3209. target: Tensor,
  3210. weight: Optional[Tensor],
  3211. reduction: int,
  3212. ignore_index: int,
  3213. ) -> Tuple[Tensor, Tensor]:
  3214. # self can be [N, C] or [C]
  3215. # target can be [N] or []
  3216. n_dims = self.dim()
  3217. channel_dim = 1
  3218. if n_dims < 2:
  3219. channel_dim = 0
  3220. if weight is not None:
  3221. if n_dims > 1:
  3222. shape = [
  3223. 1,
  3224. ] * n_dims
  3225. shape[channel_dim] = weight.shape[0]
  3226. w = weight.view(shape)
  3227. else:
  3228. w = weight
  3229. self = self * w
  3230. safe_target = torch.where(target != ignore_index, target, 0)
  3231. safe_target_ = safe_target.unsqueeze(channel_dim)
  3232. # target can be [N, 1] or [1]
  3233. result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
  3234. result = torch.where(target != ignore_index, result, 0)
  3235. if reduction == Reduction.NONE.value and n_dims > 1:
  3236. total_weight = self.new_full((), 0.0)
  3237. return result, total_weight
  3238. if weight is not None:
  3239. w = w.expand(self.shape)
  3240. wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
  3241. wsum = torch.where(target != ignore_index, wsum, 0)
  3242. total_weight = wsum.sum()
  3243. else:
  3244. total_weight = (target != ignore_index).sum().to(self)
  3245. if reduction == Reduction.SUM.value:
  3246. result = result.sum()
  3247. elif reduction == Reduction.MEAN.value:
  3248. result = result.sum() / total_weight
  3249. return result, total_weight
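# Illustrative sketch (hypothetical `_example_*` helper): the gather-based NLL loss above
# on a tiny 2-sample, 2-class batch with mean reduction and no class weights.
def _example_nll_loss_forward():
    logp = torch.log(torch.tensor([[0.25, 0.75], [0.9, 0.1]]))
    target = torch.tensor([1, 0])
    out, total_weight = _nll_loss_forward(logp, target, None, Reduction.MEAN.value, -100)
    assert torch.allclose(out, -(logp[0, 1] + logp[1, 0]) / 2)
    assert total_weight.item() == 2.0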
  3250. @register_decomposition(aten.nll_loss_forward)
  3251. @out_wrapper("output", "total_weight")
  3252. def nll_loss_forward(
  3253. self: Tensor,
  3254. target: Tensor,
  3255. weight: Optional[Tensor],
  3256. reduction: int,
  3257. ignore_index: int,
  3258. ) -> Tuple[Tensor, Tensor]:
  3259. assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
  3260. assert (
  3261. target.dim() <= 1
  3262. ), "0D or 1D target tensor expected, multi-target not supported"
  3263. no_batch_dim = self.dim() == 1 and target.dim() == 0
  3264. assert no_batch_dim or (
  3265. self.shape[0] == target.shape[0]
  3266. ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
  3267. n_classes = self.shape[-1]
  3268. assert weight is None or (
  3269. weight.dim() == 1 and weight.numel() == n_classes
  3270. ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}" # noqa: B950
  3271. return _nll_loss_forward(self, target, weight, reduction, ignore_index)
  3272. @register_decomposition(aten.nll_loss2d_forward)
  3273. @out_wrapper("output", "total_weight")
  3274. def nll_loss2d_forward(
  3275. self: Tensor,
  3276. target: Tensor,
  3277. weight: Optional[Tensor],
  3278. reduction: int,
  3279. ignore_index: int,
  3280. ) -> Tuple[Tensor, Tensor]:
  3281. return _nll_loss_forward(self, target, weight, reduction, ignore_index)
3282. # These are adapted from aten/src/ATen/native/UpSample.h, which is based on
  3283. # https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
  3284. def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
  3285. return ((A + 2) * x - (A + 3)) * x * x + 1
  3286. def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
  3287. return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
  3288. def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
  3289. A = -0.75
  3290. if t.device == torch.device("cpu"):
  3291. tt1 = torch.stack([t, 1.0 - t], dim=0)
  3292. tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0)
  3293. w03 = _upsample_cubic_convolution2(tt2, A)
  3294. w12 = _upsample_cubic_convolution1(tt1, A)
  3295. w0, w3 = torch.unbind(w03, dim=0)
  3296. w1, w2 = torch.unbind(w12, dim=0)
  3297. return w0, w1, w2, w3
  3298. else:
  3299. return (
  3300. _upsample_cubic_convolution2(t + 1.0, A),
  3301. _upsample_cubic_convolution1(t, A),
  3302. _upsample_cubic_convolution1(1.0 - t, A),
  3303. _upsample_cubic_convolution2(2.0 - t, A),
  3304. )
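# Note (illustrative; `_example_*` is a hypothetical helper, not module API): with
# A = -0.75, a sample that lands exactly on an input pixel (t = 0) gets weights
# (w0, w1, w2, w3) = (0, 1, 0, 0), and t = 0.5 gives a symmetric set with w0 == w3
# and w1 == w2.
def _example_upsample_cubic_coefficients():
    w0, w1, w2, w3 = _upsample_get_cubic_coefficients(torch.tensor(0.0))
    assert (w0.item(), w1.item(), w2.item(), w3.item()) == (0.0, 1.0, 0.0, 0.0)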
  3305. def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
  3306. coeffs2 = _upsample_get_cubic_coefficients(ts)
  3307. return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))
  3308. # Need this instead of just sum() to keep mypy happy
  3309. def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
  3310. return reduce(torch.add, ts)
  3311. def _linspace_from_neg_one(
  3312. num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device
  3313. ):
  3314. if num_steps <= 1:
  3315. return torch.tensor(0, device=device, dtype=dtype)
  3316. a = ((num_steps - 1) / num_steps) if not align_corners else 1
  3317. return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype)
  3318. def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool):
  3319. dtype = theta.dtype
  3320. device = theta.device
3321. # Using padding and summation generates a single kernel, whereas torch.stack would generate 3 kernels,
3322. # one for each individual tensor: grid_x, grid_y, grid_one
  3323. grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1)
  3324. grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1)
  3325. grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device)
  3326. # this is just a temporary hack and we should use torch.stack here once #104480 is merged
  3327. grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0)
  3328. grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0)
  3329. grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0)
  3330. return grid_x + grid_y + grid_one
  3331. def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool):
  3332. dtype = theta.dtype
  3333. device = theta.device
  3334. grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1)
  3335. grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1)
  3336. grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1)
  3337. grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device)
  3338. # this is just a temporary hack and we should use torch.stack here once #104480 is merged
  3339. grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0)
  3340. grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0)
  3341. grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0)
  3342. grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0)
  3343. return grid_x + grid_y + grid_z + grid_one
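# Note (illustrative commentary): the pad-and-add trick above builds homogeneous
# coordinates without torch.stack. After padding, grid_x only populates slot 0 of the
# last dimension, grid_y slot 1 and grid_one slot 2, so their sum is an (h, w, 3) grid
# of [x, y, 1] rows (and, in the 5d case, a (d, h, w, 4) grid of [x, y, z, 1] rows).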
  3344. def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool):
  3345. n, _, h, w = size
  3346. base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
  3347. # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
3348. # We do the matrix multiplication manually, which is faster than mm()
  3349. # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2)
  3350. grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2)
  3351. return grid.view(n, h, w, 2)
  3352. def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool):
  3353. n, _, d, h, w = size
  3354. base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
  3355. # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
3356. # We do the matrix multiplication manually, which is faster than mm()
3357. # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, d * h * w, 3)
  3358. grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2)
  3359. return grid.view(n, d, h, w, 3)
  3360. @register_decomposition(aten.affine_grid_generator)
  3361. @out_wrapper()
  3362. @pw_cast_for_opmath
  3363. def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool):
  3364. torch._check(
  3365. len(size) in (4, 5),
  3366. lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
  3367. )
  3368. if len(size) == 4:
  3369. return _affine_grid_generator_4d(theta, size, align_corners=align_corners)
  3370. else:
  3371. return _affine_grid_generator_5d(theta, size, align_corners=align_corners)
  3372. def _grid_sampler_2d(
  3373. a: Tensor,
  3374. grid: Tensor,
  3375. interpolation_mode: int = 0,
  3376. padding_mode: int = 0,
  3377. align_corners: bool = False,
  3378. _expand_grid: bool = True,
  3379. ) -> Tensor:
3380. # This method is a copy of the grid_sampler_2d implementation with an additional arg _expand_grid to
3381. # optionally expand the input grid for performance reasons.
3382. # Experimenting locally, it was found that compiled CUDA code is accelerated by ~5x
3383. # and CPU code by ~2x in bicubic mode if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2).
3384. # However, this leads to a slowdown of roughly 0.8x in CPU bilinear mode with channels-first inputs,
3385. # so we do not expand the grid for that case.
  3386. torch._check(
  3387. interpolation_mode in (0, 1, 2),
  3388. lambda: f"Invalid interpolation mode {interpolation_mode}",
  3389. )
  3390. torch._check(
  3391. padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
  3392. )
  3393. def unnormalize(coords: Tensor, size: int) -> Tensor:
  3394. # Rescale coordinates from [-1, 1] to:
  3395. # [0, size - 1] if align_corners is True
  3396. # [-.5, size -.5] if align_corners is False
  3397. mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
  3398. ofs = size * 0.5 - 0.5
  3399. return coords * mul + ofs
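# Note (illustrative worked example): for size = 4, align_corners=True maps [-1, 1]
# onto [0, 3] (grid corners hit pixel centers), while align_corners=False maps [-1, 1]
# onto [-0.5, 3.5] (grid corners hit pixel edges).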
  3400. # Reflects coordinates until they fall between low and high (inclusive).
  3401. # The bounds are passed as twice their value so that half-integer values
  3402. # can be represented as ints.
  3403. def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
  3404. if twice_low == twice_high:
  3405. return torch.zeros_like(coords)
  3406. coords_min = twice_low / 2
  3407. coords_span = (twice_high - twice_low) / 2
  3408. coords2 = (coords - coords_min).abs()
  3409. extra = torch.fmod(coords2, coords_span)
  3410. flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
  3411. return torch.where(
  3412. flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
  3413. )
  3414. def compute_coordinates(coords: Tensor, size: int) -> Tensor:
  3415. if padding_mode == 0: # Zero
  3416. return coords
  3417. elif padding_mode == 1: # Borders
  3418. return torch.clamp(coords, 0, size - 1)
  3419. else: # padding_mode == 2, Reflection
  3420. if align_corners:
  3421. coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
  3422. else:
  3423. coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
  3424. return torch.clamp(coords_reflected, 0, size - 1)
  3425. def compute_source_index(coords: Tensor, size: int) -> Tensor:
  3426. coords_un = unnormalize(coords, size)
  3427. return compute_coordinates(coords_un, size)
  3428. N, C, iH, iW = a.shape
  3429. _, oH, oW, two = grid.shape
  3430. assert two == 2
  3431. if _expand_grid:
  3432. # Let's expand grid to [N, C, oH, oW, 2]
3433. # This allows generating a single triton cuda kernel instead of two kernels.
3434. # Two kernels arise because the source indices and weights have shape (N, 1, oH, oW), xnumel=N*oH*oW,
3435. # while the output has shape (N, C, oH, oW), xnumel=N*C*oH*oW
  3436. # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW
  3437. grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2)
  3438. def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
  3439. return torch.logical_and(
  3440. 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
  3441. )
  3442. N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
  3443. C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
  3444. def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
  3445. cond = in_bounds_cond(xs, ys)
  3446. # To clip to inside valid coordinates, we map the coordinates
  3447. # to (x, y) = (0, 0) and also set the weight to 0
  3448. # We also change the shape of the tensor to the appropriate one for
  3449. # broadcasting with N_idx, C_idx for the purposes of advanced indexing
  3450. c = C if _expand_grid else 1
  3451. return tuple(
  3452. torch.where(cond, t, 0).view(N, c, oH, oW)
  3453. for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
  3454. )
  3455. def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
  3456. # Perform clipping, index into input tensor and multiply by weight
  3457. idx_x, idx_y, w_ = clip(ix, iy, w)
  3458. return a[N_idx, C_idx, idx_y, idx_x] * w_
  3459. x = grid[..., 0]
  3460. y = grid[..., 1]
  3461. if interpolation_mode == 0: # Bilinear
  3462. ix = compute_source_index(x, iW)
  3463. iy = compute_source_index(y, iH)
  3464. ix_nw, iy_nw = ix.floor(), iy.floor()
  3465. ix_ne, iy_ne = ix_nw + 1, iy_nw
  3466. ix_sw, iy_sw = ix_nw, iy_nw + 1
  3467. ix_se, iy_se = ix_ne, iy_sw
  3468. w_nw = (ix_se - ix) * (iy_se - iy)
  3469. w_ne = (ix - ix_sw) * (iy_sw - iy)
  3470. w_sw = (ix_ne - ix) * (iy - iy_ne)
  3471. w_se = (ix - ix_nw) * (iy - iy_nw)
  3472. return _sum_tensors(
  3473. get_summand(ix, iy, w)
  3474. for (ix, iy, w) in (
  3475. (ix_nw, iy_nw, w_nw),
  3476. (ix_ne, iy_ne, w_ne),
  3477. (ix_sw, iy_sw, w_sw),
  3478. (ix_se, iy_se, w_se),
  3479. )
  3480. )
  3481. elif interpolation_mode == 1: # Nearest
  3482. ix = compute_source_index(x, iW)
  3483. iy = compute_source_index(y, iH)
  3484. ix_nearest = ix.round()
  3485. iy_nearest = iy.round()
  3486. return get_summand(ix_nearest, iy_nearest, 1)
  3487. else: # interpolation_mode == 2, Bicubic
  3488. ix = unnormalize(x, iW)
  3489. iy = unnormalize(y, iH)
  3490. ix_nw = ix.floor()
  3491. iy_nw = iy.floor()
  3492. tx = ix - ix_nw
  3493. ty = iy - iy_nw
  3494. if not _expand_grid:
  3495. tx = tx.unsqueeze(1)
  3496. ty = ty.unsqueeze(1)
  3497. def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
  3498. x = compute_coordinates(ix, iW)
  3499. y = compute_coordinates(iy, iH)
  3500. return get_summand(x, y, 1)
  3501. def get_coeff(ofs: int) -> Tensor:
  3502. iy_ofs = iy_nw + (ofs - 1)
  3503. cs = (
  3504. get_value_bounded(ix_nw - 1, iy_ofs),
  3505. get_value_bounded(ix_nw, iy_ofs),
  3506. get_value_bounded(ix_nw + 1, iy_ofs),
  3507. get_value_bounded(ix_nw + 2, iy_ofs),
  3508. )
  3509. return _upsample_cubic_interp1d(cs, tx)
  3510. coeffs = tuple(get_coeff(ofs) for ofs in range(4))
  3511. return _upsample_cubic_interp1d(coeffs, ty)
  3512. @register_decomposition(aten.grid_sampler_2d)
  3513. @out_wrapper()
  3514. @pw_cast_for_opmath
  3515. def grid_sampler_2d(
  3516. a: Tensor,
  3517. grid: Tensor,
  3518. interpolation_mode: int = 0,
  3519. padding_mode: int = 0,
  3520. align_corners: bool = False,
  3521. ) -> Tensor:
  3522. return _grid_sampler_2d(
  3523. a,
  3524. grid=grid,
  3525. interpolation_mode=interpolation_mode,
  3526. padding_mode=padding_mode,
  3527. align_corners=align_corners,
  3528. )
  3529. @register_decomposition(aten.mv)
  3530. @out_wrapper()
  3531. @pw_cast_for_opmath
  3532. def mv(self, vec):
  3533. torch._check(
  3534. self.dim() == 2 and vec.dim() == 1,
  3535. lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
  3536. )
  3537. torch._check(
  3538. self.size(1) == vec.size(0),
  3539. lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
  3540. )
  3541. return (self * vec).sum(dim=1)
  3542. @register_decomposition(aten.binary_cross_entropy_with_logits)
  3543. @out_wrapper()
  3544. def binary_cross_entropy_with_logits(
  3545. self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
  3546. ):
  3547. if pos_weight is not None:
  3548. log_weight = (pos_weight - 1) * target + 1
  3549. loss = (1 - target) * self - (log_weight * F.logsigmoid(self))
  3550. else:
  3551. loss = (1 - target) * self - F.logsigmoid(self)
  3552. if weight is not None:
  3553. loss = loss * weight
  3554. return apply_loss_reduction(loss, reduction)
  3555. def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool:
  3556. # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp
  3557. t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
  3558. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  3559. if not (t1.ndim >= 3 and t2.ndim <= 2):
  3560. return False
  3561. if t2.requires_grad and not is_out:
  3562. return True
  3563. if tensor1.ndim == 2:
  3564. return False
  3565. if guard_size_oblivious(t1.numel() == 0):
  3566. return True
  3567. t1_shape = t1.shape
  3568. t1_stride = t1.stride()
  3569. return all(
  3570. st1 == st2 * s2
  3571. for (st1, st2, s2) in zip(t1_stride[:-2], t1_stride[1:-1], t1_shape[1:-1])
  3572. )
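# Illustrative sketch (hypothetical `_example_*` helper): the fold path fires for e.g. a
# contiguous batched-matrix @ matrix product, but not for a genuine 3d @ 3d bmm.
def _example_should_fold():
    assert should_fold(torch.randn(2, 3, 4), torch.randn(4, 5), is_out=False)
    assert not should_fold(torch.randn(2, 3, 4), torch.randn(7, 4, 5), is_out=False)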
  3573. @aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  3574. @aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
  3575. @out_wrapper(pass_is_out=True)
  3576. def matmul(tensor1, tensor2, *, is_out=False):
  3577. dim_tensor1 = tensor1.dim()
  3578. dim_tensor2 = tensor2.dim()
  3579. assert dim_tensor1 != 0 and dim_tensor2 != 0
  3580. if dim_tensor1 == 1 and dim_tensor2 == 1:
  3581. return torch.dot(tensor1, tensor2)
  3582. elif dim_tensor1 == 2 and dim_tensor2 == 1:
  3583. return torch.mv(tensor1, tensor2)
  3584. elif dim_tensor1 == 1 and dim_tensor2 == 2:
  3585. return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
  3586. elif dim_tensor1 == 2 and dim_tensor2 == 2:
  3587. return torch.mm(tensor1, tensor2)
  3588. elif should_fold(tensor1, tensor2, is_out):
  3589. # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
  3590. # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
  3591. # and some condition on the strides is fulfilled
  3592. # optimization: use mm instead of bmm by folding the batch of the larger tensor
  3593. # into its leading matrix dimension
  3594. transpose = dim_tensor2 > dim_tensor1
  3595. t1 = tensor2.mT if transpose else tensor1
  3596. t2 = (
  3597. tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
  3598. )
  3599. # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
  3600. # and t1 and t2 are matmul-compatible
  3601. # Why not t1.view(-1, sizes_1[-1])?
  3602. # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
  3603. # This can happen in e.g. [3, 5, 0] @ [0, 0].
  3604. sizes_1 = t1.shape
  3605. output_shape = list(sizes_1[:-1])
  3606. folded_dim1 = reduce(operator.mul, output_shape)
  3607. # Readjust output_shape if we are multiplying by a matrix
  3608. t2_is_matrix = t2.dim() == 2
  3609. if t2_is_matrix:
  3610. output_shape.append(t2.shape[1])
  3611. # This will almost always be a view.
  3612. # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation
  3613. t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
  3614. if t2_is_matrix:
  3615. # This copies if we perform a 2D @ 3D and the first tensor requires_grad
  3616. # See should_fold native/LinearAlgebra.cpp for why.
  3617. output = t1_folded.mm(t2).view(output_shape)
  3618. return output.mT.contiguous() if transpose else output
  3619. else:
  3620. return t1_folded.mv(t2).view(output_shape)
  3621. elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
3622. # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 can be a list);
  3623. # we track m1 vs m2 separately even though they must match for nicer error messages
  3624. n = tensor1.size(-2) if dim_tensor1 > 1 else 1
  3625. m1 = tensor1.size(-1)
  3626. batch_tensor1 = tensor1.shape[:-2]
  3627. m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
  3628. p = tensor2.size(-1) if dim_tensor2 > 1 else 1
  3629. batch_tensor2: List[int] = []
  3630. # TODO: handling of slice
  3631. for i in range(dim_tensor2 - 2):
  3632. batch_tensor2.append(tensor2.size(i))
  3633. # Same optimization for the gradients as that in should_fold
  3634. # If we're going to broadcast, we force it to go through the should_fold branch
  3635. if (
  3636. dim_tensor1 == 3
  3637. and dim_tensor2 == 3
  3638. and batch_tensor1[0] != batch_tensor2[0]
  3639. ):
  3640. if batch_tensor1[0] == 1 and tensor1.requires_grad:
  3641. return matmul(tensor1.squeeze(0), tensor2)
  3642. if batch_tensor2[0] == 1 and tensor2.requires_grad:
  3643. return matmul(tensor1, tensor2.squeeze(0))
  3644. # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
  3645. expand_batch_portion = list(
  3646. torch.broadcast_shapes(batch_tensor1, batch_tensor2)
  3647. )
  3648. tensor1_expand_size = expand_batch_portion + [n, m1]
  3649. expand_batch_product = prod(expand_batch_portion)
  3650. # HACK: We need reshape with symint support
  3651. tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
  3652. expand_batch_product, n, m1
  3653. )
  3654. vector_rhs = dim_tensor2 == 1
  3655. if vector_rhs:
  3656. tensor2_expand_size = expand_batch_portion + [m2]
  3657. tensor2_expanded = (
  3658. tensor2.expand(tensor2_expand_size)
  3659. .reshape(expand_batch_product, m2)
  3660. .unsqueeze(2)
  3661. )
  3662. else:
  3663. tensor2_expand_size = expand_batch_portion + [m2, p]
  3664. tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
  3665. expand_batch_product, m2, p
  3666. )
  3667. output_shape = expand_batch_portion
  3668. if dim_tensor1 > 1:
  3669. output_shape.append(n)
  3670. if dim_tensor2 > 1:
  3671. output_shape.append(p)
  3672. if vector_rhs:
  3673. return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape)
  3674. else:
  3675. return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
  3676. else:
  3677. torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
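

# Illustrative sketch (hypothetical helper, not part of the decomposition above): the
# broadcast-batched branch amounts to broadcasting the batch shapes, flattening the
# batch dims, and running a single bmm, e.g.:
def _example_batched_matmul_via_bmm():
    t1 = torch.randn(2, 1, 3, 4)  # batch [2, 1], matrix [3, 4]
    t2 = torch.randn(5, 4, 6)  # batch [5], matrix [4, 6]
    batch = torch.broadcast_shapes(t1.shape[:-2], t2.shape[:-2])  # (2, 5)
    a = t1.expand(*batch, 3, 4).reshape(-1, 3, 4)
    b = t2.expand(*batch, 4, 6).reshape(-1, 4, 6)
    return a.bmm(b).view(*batch, 3, 6)  # same result as t1 @ t2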


@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_default(
    input: Tensor,
    output_size: Tuple[int, int],
    align_corners: bool,
    scale_h: Optional[float] = None,
    scale_w: Optional[float] = None,
) -> Tensor:
    # get dimensions of original image
    _, _, in_h, in_w = input.shape

    # Calculate horizontal and vertical scaling factor
    h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h)
    w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w)

    _, dtype = utils.elementwise_dtypes(
        input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )

    # We have to create arange with int64 dtype and use .to in order to avoid
    # creating additional kernels in inductor and incurring a perf slowdown
    i = torch.arange(output_size[0], device=input.device).to(dtype=dtype)
    j = torch.arange(output_size[1], device=input.device).to(dtype=dtype)

    x_float = _compute_source_index(w_scale_factor, j, align_corners)
    y_float = _compute_source_index(h_scale_factor, i, align_corners)
    y_float = y_float.unsqueeze(-1)

    x = x_float.floor()
    y = y_float.floor()

    # We should also clamp xscale/yscale
    # See guard_index_and_lambda in UpSample.h
    yscale = (y_float - y).clamp(0.0, 1.0)
    xscale = (x_float - x).clamp(0.0, 1.0)
    x = x.to(torch.int64)
    y = y.to(torch.int64)

    iys_ofs = (y - 1, y, y + 1, y + 2)
    ixs_ofs = (x - 1, x, x + 1, x + 2)

    weights_x = _upsample_get_cubic_coefficients(xscale)
    weights_y = _upsample_get_cubic_coefficients(yscale)

    weights_precision_x, weights_precision_y = None, None
    if input.dtype == torch.uint8:
        weights_precision_x = _compute_weight_precision(weights_x)
        weights_precision_y = _compute_weight_precision(weights_y)

        weights_x = [
            (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16)
            for w in weights_x
        ]
        weights_y = [
            (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16)
            for w in weights_y
        ]

    def load_bounded(ys, xs):
        y_idx = torch.clamp(ys, 0, in_h - 1)
        x_idx = torch.clamp(xs, 0, in_w - 1)
        v = aten._unsafe_index(input, [None, None, y_idx, x_idx])
        return v

    def get_x_interp(y):
        src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs)
        if input.dtype == torch.uint8:
            assert weights_precision_x is not None
            return _sum_tensors_uint8(src_x, weights_x, weights_precision_x)
        return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x))

    src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs)
    if input.dtype == torch.uint8:
        assert weights_precision_y is not None
        result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y)
    else:
        result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y))

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)
    result = result.contiguous(memory_format=memory_format)
    return result
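

# Illustrative sketch (hypothetical, not used by the decomposition above): bicubic
# interpolation is separable -- blend 4 horizontal taps per row with weights_x, then
# blend the 4 resulting rows with weights_y, mirroring get_x_interp / src_y above.
def _example_separable_bicubic(taps, weights_x, weights_y):
    # taps: a 4x4 nested sequence of neighbouring values around the sample point
    rows = [sum(p * wx for p, wx in zip(row, weights_x)) for row in taps]
    return sum(r * wy for r, wy in zip(rows, weights_y))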


@register_decomposition(aten.upsample_bicubic2d.vec)
@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_vec(
    a: Tensor,
    output_size: Optional[Tuple[int, int]],
    align_corners: bool,
    scale_factors: Optional[Tuple[float, float]] = None,
) -> Tensor:
    torch._check(
        bool(output_size) + bool(scale_factors) == 1,
        lambda: "Must specify exactly one of output_size and scale_factors.",
    )
    if output_size is None:
        assert scale_factors is not None
        output_size = cast(
            Tuple[int, int],
            tuple(
                sym_int(sym_float(w) * scale)
                for w, scale in zip(a.shape[2:], scale_factors)
            ),
        )
    scale_h, scale_w = scale_factors if scale_factors else (None, None)
    return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
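

# Illustrative sketch (hypothetical helper): how upsample_bicubic2d_vec derives
# output_size from scale_factors when output_size is None, e.g. spatial sizes (4, 6)
# with scale_factors (2.0, 1.5) give (8, 9).
def _example_output_size_from_scales(spatial_sizes, scale_factors):
    return tuple(sym_int(sym_float(s) * f) for s, f in zip(spatial_sizes, scale_factors))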


@register_decomposition(aten.reflection_pad1d)
@register_decomposition(aten.reflection_pad2d)
@register_decomposition(aten.reflection_pad3d)
@pw_cast_for_opmath
@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
        dim_idx = torch.arange(-left, middle + right, device=a.device)
        return middle - 1 - (middle - 1 - dim_idx.abs()).abs()

    return _reflection_or_replication_pad(
        a,
        padding,
        idx,
    )


@register_decomposition(aten.replication_pad1d)
@register_decomposition(aten.replication_pad2d)
@register_decomposition(aten.replication_pad3d)
@pw_cast_for_opmath
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
        dim_idx = torch.arange(-left, middle + right, device=a.device)
        return torch.clamp(dim_idx, 0, middle - 1)

    return _reflection_or_replication_pad(
        a,
        padding,
        idx,
    )


def _reflection_or_replication_pad(
    a: Tensor,
    padding: Tuple[int, ...],
    idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
    dim = len(padding) // 2
    torch._check(
        a.dim() in (dim + 1, dim + 2),
        lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
    )
    inp_shape = a.shape[-dim:]
    nc_dim = a.dim() - dim

    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]

    result = a
    for i in range(dim):
        idx: List[Any] = [None] * result.dim()
        idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
        result = aten._unsafe_index(result, idx)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(result)
    result = result.contiguous(memory_format=memory_format)
    return result
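

# Worked example (illustrative) for the idx helpers above, with middle = 4 and
# left = right = 2, i.e. a row [a, b, c, d] padded by 2 on each side:
#   dim_idx           = [-2, -1, 0, 1, 2, 3, 4, 5]
#   reflection index  = [ 2,  1, 0, 1, 2, 3, 2, 1]  -> [c, b, a, b, c, d, c, b]
#   replication index = [ 0,  0, 0, 1, 2, 3, 3, 3]  -> [a, a, a, b, c, d, d, d]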


@register_decomposition(aten.aminmax)
@out_wrapper("min", "max")
def aminmax(self, *, dim=None, keepdim=False):
    amin = torch.amin(self, dim=dim, keepdim=keepdim)
    amax = torch.amax(self, dim=dim, keepdim=keepdim)
    return amin, amax


@register_decomposition(aten.nansum)
@out_wrapper()
def nansum(self, dim=None, keepdim=False, *, dtype=None):
    return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype)


@register_decomposition([aten.arange.default, aten.arange.out])
@out_wrapper()
def arange_default(
    end: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
):
    return aten.arange.start_step(
        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_decomposition([aten.arange.start])
def arange_start(
    start: NumberType,
    end: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
):
    return aten.arange.start_step(
        start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_decomposition(out_dtype)
def out_dtype_decomp(*args, **kwargs):
    from torch._higher_order_ops.out_dtype import out_dtype_dense

    return out_dtype_dense(*args, **kwargs)


@register_decomposition(aten.multi_margin_loss)
@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: NumberType = 1,
    margin: NumberType = 1,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    input = torch.atleast_2d(input)
    target = torch.atleast_1d(target)
    nframe = input.shape[0]
    dim = input.shape[1]
    torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
    torch._check(
        input.ndim == 2 and dim != 0,
        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
    )
    torch._check(
        target.ndim == 1 and target.numel() == nframe,
        lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
    )
    if weight is not None:
        weight = torch.atleast_1d(weight)
        torch._check(
            weight.ndim == 1 and weight.numel() == dim,  # type: ignore[union-attr]
            lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",  # type: ignore[union-attr]
        )
    target = target.unsqueeze(1)
    u = torch.gather(input, dim=1, index=target)
    z = margin - u + input
    z = z.clamp_min(0)
    z = z if p == 1 else z * z
    if weight is not None:
        z = z * weight[target]
    idx = torch.arange(dim, device=input.device)
    z = torch.where(idx != target, z, 0)
    if reduction == Reduction.MEAN.value:
        return z.mean()
    elif reduction == Reduction.SUM.value:
        return z.sum() / z.shape[1]
    else:
        return z.mean(dim=1)
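

# Reference sketch (hypothetical, not used above): the per-sample value computed by
# multi_margin_loss with reduction="none" and no weight, written in plain Python.
def _example_multi_margin_per_sample(scores, y, p=1, margin=1.0):
    dim = len(scores)
    return (
        sum(
            max(0.0, margin - scores[y] + scores[j]) ** p
            for j in range(dim)
            if j != y
        )
        / dim
    )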


@register_decomposition(aten.multilabel_margin_loss_forward)
@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd)
@out_wrapper("output", "is_target")
def multilabel_margin_loss_forward(
    input: Tensor,
    target: Tensor,
    reduction: int,
) -> Tuple[Tensor, Tensor]:
    orig_input_shape = input.shape
    orig_target_shape = target.shape
    input = torch.atleast_2d(input)
    target = torch.atleast_2d(target)
    dim = input.shape[1]
    torch._check(
        len(orig_input_shape) <= 2 and dim != 0,
        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}",
    )
    torch._check(
        len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape,
        lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}",
    )
    # ignore labels after the first -1; when -1 is not present, every entry counts
    idx = torch.arange(dim, device=target.device)
    is_end = target == -1
    end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True)
    # target indices
    target_mask = idx < end_idx
    # mask target to be able to use gather, which doesn't allow -1
    tidx0 = torch.where(target_mask, target, 0)
    u = torch.gather(input, dim=-1, index=tidx0)
    # is_target
    tidx1 = torch.where(target_mask, target, -1)
    is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1)
    # loss
    z = 1.0 - u.T.unsqueeze(dim=-1) + input
    z = z.clamp_min(0)
    z = z / dim
    # mask out the loss at target positions
    z = torch.where(is_target, 0, z)
    # reduction
    if reduction == Reduction.MEAN.value:
        z = z.sum(dim=(0, -1)).mean()
    elif reduction == Reduction.SUM.value:
        z = z.sum()
    else:
        z = z.sum(dim=(0, -1))
    # result
    is_target = is_target.to(input.dtype).reshape(orig_target_shape)
    return z, is_target
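

# Illustrative note: the -1 sentinel handling above means that for a target row such
# as [2, 0, -1, 1] only the labels before the first -1 (here {2, 0}) contribute; a row
# with no -1 uses all of its entries as labels.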


# scaled_dot_product_attention used to be decomposed in pre-autograd, given that
# it calls _scaled_dot_product_attention_math and
# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd
# kernel. As a result it's decomposed into ops with finer granularity.
# However recent PRs (#103826 #105131 #115913) added new logic in
# scaled_dot_product_attention and now it calls
# _scaled_dot_product_flash_attention_for_cpu in the export path. This results
# in _scaled_dot_product_flash_attention_for_cpu showing up in the export result.
# This decomposition ensures scaled_dot_product_attention is still decomposed
# the same way as before, i.e., going through
# _scaled_dot_product_attention_math. Note that this decomp rule should be
# excluded by inductor.
@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
def scaled_dot_product_flash_attention_for_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
    dtype = query.dtype
    torch._check(
        torch.is_floating_point(query),
        lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
    )
    torch._check(
        query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
        lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
    )
    torch._check(
        dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}"
    )
    torch._check(
        query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3],
        lambda: "q, k, v should have the same head size",
    )

    output, attn = aten._scaled_dot_product_attention_math.default(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        dropout_mask=None,
        scale=scale,
    )
    # Why this change?
    # In pre-dispatch export, scaled_dot_product_attention is executed via
    # flash_attention. flash_attention allocates its output tensor as (N, L, H, E)
    # and then transposes it to (N, H, L, E), which is the expected return layout
    # of scaled_dot_product_attention.
    # Assume x: [N, H, L, E] is the output of sdpa.
    # In the MHA code, this output is permuted via (2, 0, 1, 3) to get a
    # (L, N, H, E) tensor:
    #   x = x.permute(2, 0, 1, 3).contiguous()
    # and then viewed via
    #   x = x.view(L * N, H * E)
    # During pre-autograd dispatch, the call to contiguous is not traced because
    # the flash_attention output is already contiguous after the x.permute, so the
    # view is valid on it directly.
    # However, during 2nd stage export (post-dispatch) we run the _math variant
    # instead of flash* to get the decomposition. The _math variant returns
    # x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) yields x: [L, N, H, E], and
    # without converting this to a contiguous tensor the subsequent view is not
    # valid and the export fails.
    # The solution is to make the tensor returned by the decomp have exactly the
    # same view as the *flash* variant:
    # - the flash variant's output is contiguous as [N, L, H, E]
    # - the _math variant's output is contiguous as [N, H, L, E]
    # out = out.transpose(1, 2).contiguous() makes the output contiguous in
    # [N, L, H, E]. The subsequent transpose(1, 2) then returns a view on which
    # the aforementioned code snippet, as shown above, is valid:
    #   x = x.permute(2, 0, 1, 3).contiguous()
    #   x = x.view(L * N, H * E)
    # Really the invariant you want to maintain is: a pre-dispatch op output and
    # its decomposed representation must return tensors with the same view and dims.
    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
    return (output.transpose(1, 2), attn)
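

# Illustrative sketch (hypothetical, not used above): the transpose/contiguous/transpose
# dance leaves the returned tensor with the same [N, L, H, E]-contiguous storage the
# flash variant produces, even though its logical shape stays [N, H, L, E].
def _example_flash_like_layout(n=2, h=3, s=4, e=5):
    out = torch.randn(n, h, s, e)  # _math variant layout: contiguous as [N, H, L, E]
    out = out.transpose(1, 2).contiguous().transpose(1, 2)
    # shape is still [N, H, L, E], but the storage now matches flash's [N, L, H, E]
    return out.shape, out.transpose(1, 2).is_contiguous()  # (torch.Size([2, 3, 4, 5]), True)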


def register_inplace(aten_op, outplace_op):
    @register_decomposition(aten_op)
    def inplace_op(*args, **kwargs):
        out = outplace_op(*args, **kwargs)
        return args[0].copy_(out)

    return inplace_op


@register_decomposition([aten.baddbmm])
@out_wrapper()
@pw_cast_for_opmath
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    result = torch.bmm(batch1, batch2)
    if not isinstance(alpha, numbers.Number) or alpha != 1:
        result = result * alpha
    if beta == 0:
        return result
    if not isinstance(beta, numbers.Number) or beta != 1:
        self = self * beta
    return self + result
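

# Note (illustrative): the decomposition above computes
#     out = beta * self + alpha * (batch1 @ batch2)
# with the beta == 0 branch returning alpha * (batch1 @ batch2) directly, so `self` is
# ignored entirely in that case rather than being multiplied by zero.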


@register_decomposition(aten.floor_divide)
@out_wrapper()
def floor_divide(self, other):
    return torch.div(self, other, rounding_mode="floor")


@register_decomposition(aten.sym_numel)
def sym_numel(t):
    return functools.reduce(operator.mul, t.shape, 1)


@register_decomposition([aten.sum.default, aten.sum.out])
def sum_default(
    self: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> Tensor:
    if out is None:
        return aten.sum.dim_IntList(self, [], dtype=dtype)
    else:
        return aten.sum.IntList_out(self, [], dtype=dtype, out=out)


@register_decomposition([aten.squeeze.default, aten.squeeze.dim])
def squeeze_default(self: Tensor, dim: Optional[int] = None):
    if dim is None:
        return aten.squeeze.dims(self, list(range(self.dim())))
    else:
        return aten.squeeze.dims(self, [dim])


@register_decomposition(torch.ops.aten._weight_norm_interface)
def _weight_norm_interface(x, y, dim=0):
    # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
    keep_dim = tuple(i for i in range(len(x.shape)) if i != dim)
    norm = x.norm(2, keep_dim, keepdim=True)
    return x * (y / norm), norm
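

# Note (illustrative): with x as the direction tensor v and y as the magnitude g, the
# decomposition above computes w = v * (g / ||v||), where the 2-norm is taken over
# every dim except `dim`; the norm itself is returned as the second output.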


@register_decomposition(aten.isin)
@out_wrapper()
def isin(elements, test_elements, *, assume_unique=False, invert=False):
    # handle when either elements or test_elements are Scalars (they can't both be)
    if not isinstance(elements, torch.Tensor):
        elements = torch.tensor(elements, device=test_elements.device)
    if not isinstance(test_elements, torch.Tensor):
        test_elements = torch.tensor(test_elements, device=elements.device)

    if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
        return isin_default(elements, test_elements, invert=invert)
    else:
        return isin_sorting(
            elements, test_elements, assume_unique=assume_unique, invert=invert
        )


def isin_default(elements, test_elements, *, invert=False):
    if elements.numel() == 0:
        return torch.empty_like(elements, dtype=torch.bool)

    x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
    if not invert:
        cmp = x == test_elements
    else:
        cmp = x != test_elements
    dim = tuple(range(-1, -test_elements.ndim - 1, -1))
    return cmp.any(dim=dim)


def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):
    elements_flat = elements.flatten()
    test_elements_flat = test_elements.flatten()
    if assume_unique:
        # This is the same as the aten implementation. For
        # assume_unique=False, we cannot use unique() here, so we use a
        # version with searchsorted instead.
        all_elements = torch.cat([elements_flat, test_elements_flat])
        sorted_elements, sorted_order = torch.sort(all_elements, stable=True)

        duplicate_mask = sorted_elements[1:] == sorted_elements[:-1]
        duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False)

        if invert:
            duplicate_mask = duplicate_mask.logical_not()

        mask = torch.empty_like(duplicate_mask)
        mask = mask.index_copy(0, sorted_order, duplicate_mask)

        return mask[0 : elements.numel()]
    else:
        sorted_test_elements, _ = torch.sort(test_elements_flat)
        idx = torch.searchsorted(sorted_test_elements, elements_flat)
        test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0)
        cmp = sorted_test_elements[test_idx] == elements_flat
        cmp = cmp.logical_not() if invert else cmp
        return cmp.reshape(elements.shape)
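

# Worked example (illustrative) for the searchsorted branch above:
#   test_elements = [3, 1, 9]  -> sorted_test_elements = [1, 3, 9]
#   elements      = [3, 4]     -> idx = searchsorted([1, 3, 9], [3, 4]) = [1, 2]
#   sorted_test_elements[idx]  = [3, 9] -> cmp = [True, False]
# i.e. 3 is found in test_elements and 4 is not.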


@register_decomposition(aten.take)
@out_wrapper()
def take(self, index):
    flattened = self.reshape(-1)
    return flattened[index]


@register_decomposition(aten.resize_as)
def resize_as(self, other, memory_format=None):
    if memory_format is None:
        memory_format = torch.contiguous_format
    if memory_format == torch.preserve_format:
        memory_format = suggest_memory_format(other)
    return aten.resize(self, other.shape, memory_format=memory_format)


register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
register_inplace(aten.baddbmm_, aten.baddbmm)
register_inplace(aten.fill_, aten.fill)
register_inplace(aten.gelu_, aten.gelu)
register_inplace(aten.hardswish_, aten.hardswish)
register_inplace(aten.hardtanh_, aten.hardtanh)
register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
register_inplace(aten.__iand__, aten.__and__)
register_inplace(aten.__ilshift__, aten.__lshift__)
register_inplace(aten.index_put_, aten.index_put)
register_inplace(aten.index_reduce_, aten.index_reduce)
register_inplace(aten.__ior__, aten.__or__)
register_inplace(aten.__irshift__, aten.__rshift__)
register_inplace(aten.__ixor__, aten.__xor__)
register_inplace(aten.leaky_relu_, aten.leaky_relu)
register_inplace(aten.logit_, aten.logit)
register_inplace(aten.relu_, aten.relu)
register_inplace(aten.renorm_, aten.renorm)
register_inplace(aten.round_, aten.round)
register_inplace(aten.scatter_, aten.scatter)
register_inplace(aten.scatter_add_, aten.scatter_add)
register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
register_inplace(aten.silu_, aten.silu)