- # mypy: allow-untyped-defs
- import builtins
- import collections
- import inspect
- import itertools
- import math
- import operator
- import warnings
- from collections.abc import Iterable
- from enum import Enum
- from functools import partial, reduce, singledispatch, wraps
- from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union
- import torch
- import torch._prims as prims
- import torch._prims_common as utils
- from torch import sym_float, sym_int
- from torch._prims_common import (
- BoolLike,
- DeviceLikeType,
- Dim,
- DimsSequenceType,
- DimsType,
- dtype_to_type,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- FloatLike,
- FloatWithoutSymFloat,
- IntLike,
- is_weakly_lesser_type,
- Number,
- NumberType,
- RealNumberType,
- REDUCTION_OUTPUT_TYPE_KIND,
- ShapeType,
- StrideType,
- TensorLike,
- TensorLikeType,
- TensorOrNumberLikeType,
- TensorSequenceType,
- )
- from torch._prims_common.wrappers import (
- _maybe_convert_to_dtype,
- _maybe_resize_out,
- _safe_copy_out,
- elementwise_type_promotion_wrapper,
- elementwise_unary_scalar_wrapper,
- out_wrapper,
- )
- # Experimental module containing prototype Python references for existing
- # PyTorch operations.
- __all__ = [
- #
- # Elementwise Unary References
- #
- "abs",
- "acos",
- "acosh",
- "asinh",
- "asin",
- "atan",
- "atanh",
- "bitwise_not",
- # "cbrt", # No corresponding torch operation
- "ceil",
- "conj_physical",
- "cos",
- "cosh",
- "count_nonzero",
- "deg2rad",
- "digamma",
- "erf",
- "erfinv",
- "erfc",
- "exp",
- "expm1",
- "exponential",
- "exp2",
- "fill",
- "fill_",
- "floor",
- "frac",
- "geometric",
- "index_add",
- "index_copy",
- "index_copy_",
- "index_select",
- "index_fill",
- "index_fill_",
- "isfinite",
- "isinf",
- "isposinf",
- "isneginf",
- "isnan",
- "isreal",
- "i0",
- "lerp",
- "lgamma",
- "log",
- "log1p",
- "log2",
- "log10",
- "log_normal",
- "log_softmax",
- "mvlgamma",
- "norm",
- "normal",
- "nan_to_num",
- "neg",
- "positive",
- "rad2deg",
- "reciprocal",
- "round", # TODO: model kwargs
- "sigmoid",
- "sgn",
- "sign",
- "signbit",
- "sin",
- "sinc",
- "sinh",
- "softmax",
- "sqrt",
- "square",
- "tan",
- "tanh",
- "trace",
- "trunc",
- #
- # Elementwise Binary References
- #
- "add",
- "atan2",
- "bitwise_and",
- "bitwise_left_shift",
- "bitwise_or",
- "bitwise_right_shift",
- "bitwise_xor",
- "clamp_min",
- "clamp_max",
- "copysign",
- "div",
- "eq",
- "float_power",
- "floor_divide",
- "fmax",
- "fmin",
- "fmod",
- "gcd",
- "ge",
- "gt",
- "heaviside",
- "hypot",
- "igamma",
- "igammac",
- "imag",
- "isclose",
- "lcm",
- # 'ldexp',
- "le",
- "logaddexp",
- "logaddexp2",
- "logical_and",
- "logical_not",
- "logical_or",
- "logical_xor",
- "logsumexp",
- "lt",
- # 'max', # implement with reductions
- "maximum",
- # 'min', # implement with reductions
- "minimum",
- "mul",
- "ne",
- "nextafter",
- # 'polar', # abs, cos, sin
- "pow",
- "real",
- "rpow",
- "remainder",
- "rsub",
- "rtruediv",
- "rfloordiv",
- "sub",
- "true_divide",
- "trunc_divide",
- "xlogy",
- #
- # Elementwise Ternary References
- #
- "addcdiv",
- "addcmul",
- "clamp",
- #
- # Conditional references
- #
- "masked_fill",
- "masked_fill_",
- "where",
- #
- # Data conversion and movement references
- #
- "clone",
- "copy_to", # TODO: add OpInfo (or implement .to)
- "item",
- "to",
- #
- # Reduction ops
- #
- "all",
- "amax",
- "amin",
- "any",
- "cumsum",
- "cumprod",
- "mean",
- "dot",
- "vdot",
- "std",
- "std_mean",
- "sum",
- "sum_to_size",
- "prod",
- "var",
- "var_mean",
- #
- # Linear algebra ops
- #
- "addr",
- #
- # View & Shape Ops
- #
- "alias",
- "atleast_1d",
- "atleast_2d",
- "atleast_3d",
- "as_strided",
- "as_strided_scatter",
- "block_diag",
- "broadcast_shapes",
- "broadcast_tensors",
- "broadcast_to",
- "cat",
- "chunk",
- "column_stack",
- "conj",
- "constant_pad_nd",
- "contiguous",
- "diag_embed",
- "diag",
- "diagonal",
- "diagonal_copy",
- "diagonal_scatter",
- "dsplit",
- "dstack",
- "expand",
- "expand_as",
- "flatten",
- "flip",
- "fliplr",
- "flipud",
- "hsplit",
- "hstack",
- "meshgrid",
- "movedim",
- "narrow",
- "narrow_copy",
- "native_group_norm",
- "native_layer_norm",
- "permute",
- "ravel",
- "repeat",
- "reshape",
- "reshape_as",
- "roll",
- "rot90",
- "rsqrt",
- "stack",
- "swap_axes", # alias for transpose
- "squeeze",
- "t",
- "T",
- "take_along_dim",
- "tensor_split",
- "transpose",
- "unfold",
- "unfold_copy",
- "unsqueeze",
- "view",
- "view_as",
- "vsplit",
- "vstack",
- "view_as_complex",
- "unflatten",
- "unbind",
- "triu",
- "tril",
- "triu_indices",
- "tril_indices",
- #
- # Tensor Creation
- #
- "arange",
- "cauchy",
- "empty",
- "empty_like",
- "empty_permuted",
- "empty_strided",
- "eye",
- "full",
- "full_like",
- "linspace",
- "logspace",
- "new_empty",
- "new_empty_strided",
- "new_full",
- "new_ones",
- "new_zeros",
- "ones",
- "ones_like",
- "randn",
- "scalar_tensor",
- "zero",
- "zeros",
- "zeros_like",
- #
- # Test-related functions
- #
- "allclose",
- "equal",
- #
- # Statistical operations
- #
- "bucketize",
- #
- # Misc
- #
- "is_complex",
- "renorm",
- "stft",
- "istft",
- ]
- Tensor = torch.Tensor
- DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
- aten = torch._ops.ops.aten
- # Note that the docstrings for the public methods from this file are in
- # torch/_torch_docs.py
- def is_noncontiguous_supported(device):
- if device is not None and device.type == "hpu":
- return False
- return True
- def handle_noncontiguous_outputs(input_tlist, output):
- device = None
- from torch._subclasses.fake_tensor import FakeTensor
- for t in input_tlist:
- if isinstance(t, FakeTensor):
- device = t.fake_device
- break
- if not is_noncontiguous_supported(device):
- output = output.contiguous()
- return output
- def _broadcast_shapes(*_shapes):
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- shapes = tuple(
- (x,) if isinstance(x, IntLike) else x
- for x in filter(lambda x: x is not None, _shapes)
- )
- # Short-circuits on no input
- if len(shapes) == 0:
- return None
- # Type checking
- # TODO: make common validations available as utils
- for shape in shapes:
- assert isinstance(shape, Sequence)
- # Computes common shape
- common_shape = [
- 1,
- ] * reduce(max, (len(shape) for shape in shapes))
- for arg_idx, shape in enumerate(shapes):
- for idx in range(-1, -1 - len(shape), -1):
- if guard_size_oblivious(common_shape[idx] == 1):
- if shape[idx] < 0:
- raise ValueError(
- "Attempting to broadcast a dimension with negative length!"
- )
- common_shape[idx] = shape[idx]
- elif guard_size_oblivious(shape[idx] != 1):
- if common_shape[idx] != shape[idx]:
- raise RuntimeError(
- f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
- f"should be broadcastable to {common_shape}"
- )
- return common_shape
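- # Editorial sketch of the rule implemented above (doctest-style, not part of
- # the module): shapes are right-aligned and a length-1 dimension stretches
- # to match the other operand.
- # >>> _broadcast_shapes((3, 1), (1, 4))
- # [3, 4]
- # >>> _broadcast_shapes((5, 3, 1), (3, 4))
- # [5, 3, 4]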
- def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
- # Computes common shape
- common_shape = _broadcast_shapes(
- *(t.shape if isinstance(t, TensorLike) else None for t in args)
- )
- def __maybe_broadcast(x, shape):
- if x is None:
- return None
- elif isinstance(x, Number):
- return x
- elif isinstance(x, TensorLike):
- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
- return x
- if not utils.same_shape(x.shape, common_shape):
- return x.expand(common_shape)
- return x
- else:
- raise RuntimeError(
- "Unexpected type when broadcasting: " + str(type(x)) + "!"
- )
- return tuple(__maybe_broadcast(x, common_shape) for x in args)
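- # Editorial example for this internal helper: Python numbers pass through
- # untouched, while tensors are expanded to the common shape (CPU scalar
- # tensors are preserved by default).
- # >>> x, y = torch.ones(3, 1), torch.ones(1, 4)
- # >>> [t.shape for t in _maybe_broadcast(x, y, 2.0)[:2]]
- # [torch.Size([3, 4]), torch.Size([3, 4])]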
- # Utilities should come BEFORE this import
- from torch._decomp import register_decomposition
- #
- # Elementwise unary references
- #
- infer_aten_op = object()
- # TODO: add type promotion support
- def _make_elementwise_unary_reference(
- type_promotion_kind,
- *,
- aten_op=infer_aten_op,
- extra_meta=None,
- ) -> Callable:
- def inner(prim: Callable):
- nonlocal aten_op
- @wraps(prim)
- @out_wrapper()
- @elementwise_unary_scalar_wrapper
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a",),
- type_promotion_kind=type_promotion_kind,
- )
- def _ref(a: TensorLikeType) -> TensorLikeType:
- if extra_meta is not None:
- extra_meta(a)
- output = prim(a)
- return handle_noncontiguous_outputs([a], output)
- if aten_op is infer_aten_op:
- aten_op = utils.get_aten_op(prim, prim.__name__)
- if aten_op is not None:
- register_decomposition(aten_op)(_ref)
- return _ref
- return inner
- def _make_alias(fn, name):
- """
- This function defines an alias of another function and sets its __name__ attribute.
- It also sets its __module__ attribute to the module of the caller.
- Note that when naively doing `alias = fn`, we have `alias.__name__ == fn.__name__`
- and `alias.__module__ == fn.__module__`, i.e. the alias reports the original name.
- """
- def _fn(*args, **kwargs):
- return fn(*args, **kwargs)
- _fn.__name__ = name
- _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"] # type: ignore[union-attr]
- return _fn
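- # Editorial example: the `mvlgamma` alias defined further down keeps the
- # behavior of the wrapped function but reports its own name.
- # >>> mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")
- # >>> mvlgamma.__name__
- # 'mvlgamma'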
- def _make_inplace(fn):
- """
- Given a function with an out variant (i.e. one using `out_wrapper()`), returns its in-place variant.
- See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
- """
- # nb. We use `a` as the name of the first argument, matching the unary references
- @wraps(fn)
- def _fn(a, *args, **kwargs):
- return fn(a, *args, out=a, **kwargs)
- inplace_name = f"{fn.__name__}_"
- _fn.__name__ = inplace_name
- _fn = register_decomposition(getattr(aten, inplace_name))(_fn)
- # We access the __all__ attribute of the module where fn is defined
- # There may be a cleaner way of doing this...
- from inspect import getmodule
- _all = getmodule(fn).__all__ # type: ignore[union-attr]
- if inplace_name not in _all:
- _all.append(inplace_name)
- return _fn
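- # Editorial sketch: given a ref with an out variant, `_make_inplace` derives
- # the in-place op by routing `out` back to the first argument, so that
- # e.g. `abs_ = _make_inplace(abs)` yields a ref behaving like `Tensor.abs_`.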
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
- def abs(a):
- return prims.abs(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def acos(a):
- return prims.acos(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def acosh(a):
- return prims.acosh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def asin(a):
- return prims.asin(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def asinh(a):
- return prims.asinh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def atan(a):
- return prims.atan(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def atanh(a):
- return prims.atanh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def bitwise_not(a):
- return prims.bitwise_not(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def ceil(a):
- return prims.ceil(a)
- @register_decomposition(aten.is_complex)
- def is_complex(input: TensorLikeType):
- return utils.is_complex_dtype(input.dtype)
- @register_decomposition(aten.conj_physical)
- @out_wrapper()
- def conj_physical(input: TensorLikeType):
- if not utils.is_complex_dtype(input.dtype):
- return input
- return prims.conj_physical(input)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def cos(a):
- return prims.cos(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def cosh(a):
- return prims.cosh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def digamma(a):
- return prims.digamma(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erf(a):
- return prims.erf(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erfinv(a):
- return prims.erf_inv(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erfc(a):
- return prims.erfc(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def exp(a):
- return prims.exp(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def expm1(a):
- return prims.expm1(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def exp2(a):
- return prims.exp2(a)
- # Fill has its own implementation because it has a value parameter
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a,"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- assert isinstance(value, Number)
- python_type = utils.dtype_to_type(a.dtype)
- if not utils.is_weakly_lesser_type(type(value), python_type):
- msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
- raise ValueError(msg)
- return prims.fill(a, value)
- def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
- r = prims.fill(a, value)
- prims.copy_to(a, r)
- return a
- @register_decomposition(aten.zero)
- @out_wrapper()
- def zero(input: TensorLikeType) -> TensorLikeType:
- return torch.zeros_like(input)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def floor(a):
- return prims.floor(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def frac(x: TensorLikeType) -> TensorLikeType:
- trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
- return torch.sub(x, trunc_x)
- # imag does not use _make_elementwise_unary_reference because it does not support out
- def imag(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- torch._check(
- utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
- )
- return prims.imag(a)
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- aten_op=None, # CompositeImplicitAutograd
- )
- def isfinite(a: TensorLikeType) -> TensorLikeType:
- if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
- return prims.isfinite(a)
- return ones_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isinf(a: TensorLikeType) -> TensorLikeType:
- if utils.is_complex_dtype(a.dtype):
- return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
- if utils.is_float_dtype(a.dtype):
- return torch.abs(a) == float("inf")
- return torch.zeros_like(a, dtype=torch.bool)
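- # Editorial example: for complex inputs, isinf is True when either the real
- # or the imaginary component is infinite.
- # >>> torch.isinf(torch.tensor(complex(float("inf"), 0.0)))
- # tensor(True)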
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isposinf(a: TensorLikeType) -> TensorLikeType:
- torch._check(
- not utils.is_complex_dtype(a.dtype),
- lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
- )
- if utils.is_float_dtype(a.dtype):
- return a == float("inf")
- return torch.zeros_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isneginf(a: TensorLikeType) -> TensorLikeType:
- torch._check(
- not utils.is_complex_dtype(a.dtype),
- lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
- )
- if utils.is_float_dtype(a.dtype):
- return a == float("-inf")
- return torch.zeros_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isnan(a: TensorLikeType) -> TensorLikeType:
- return prims.ne(a, a)
- # alias
- mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type]
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- aten_op=None, # CompositeImplicitAutograd
- )
- def isreal(a: TensorLikeType) -> TensorLikeType:
- if utils.is_complex_dtype(a.dtype):
- return torch.imag(a) == 0
- return torch.ones_like(a, dtype=torch.bool)
- # TODO: if this is special maybe it should be defined there and imported here?
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
- )
- def i0(a):
- return prims.bessel_i0(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def lgamma(a):
- return prims.lgamma(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log(a):
- return prims.log(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log1p(a):
- return prims.log1p(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log2(a):
- return prims.log2(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log10(a):
- return prims.log10(a)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def log_softmax(
- a: TensorLikeType,
- dim: int,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(result_dtype)
- a_ = _maybe_convert_to_dtype(a, computation_dtype)
- return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value]
- @register_decomposition(aten.logsumexp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def logsumexp(
- self: TensorLikeType, dim: DimsType, keepdim: bool = False
- ) -> TensorLikeType:
- if not isinstance(dim, Iterable):
- dim = (dim,)
- if self.numel() == 0:
- return torch.sum(torch.exp(self), dim, keepdim).log()
- maxes = torch.amax(self, dim, keepdim=True)
- maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
- maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
- result = torch.sum(torch.exp(self - maxes), dim, keepdim)
- return result.log().add(maxes_squeezed)
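- # Editorial note: subtracting the (finite) rowwise max before exponentiating
- # keeps exp() from overflowing; the max is added back on the log scale.
- # >>> torch.logsumexp(torch.tensor([1000.0, 1000.0]), dim=0)
- # tensor(1000.6931)
- # A naive log(sum(exp(x))) would overflow to inf here.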
- @register_decomposition(aten.nan_to_num)
- @out_wrapper()
- def nan_to_num(
- a: TensorLikeType,
- nan: Optional[NumberType] = 0.0,
- posinf: Optional[NumberType] = None,
- neginf: Optional[NumberType] = None,
- ) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- return a.clone()
- if nan is None:
- nan = 0.0
- if posinf is None:
- posinf = torch.finfo(a.dtype).max
- if neginf is None:
- neginf = torch.finfo(a.dtype).min
- result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload]
- result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload]
- result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload]
- return result
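- # Editorial example: by default NaN maps to 0.0 and +/-inf clamp to the
- # dtype's finite extremes.
- # >>> torch.nan_to_num(torch.tensor([float("nan"), float("inf"), -float("inf")]))
- # tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38])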
- def _neg_meta(a: TensorLikeType):
- torch._check(
- a.dtype is not torch.bool,
- lambda: (
- "Negation, the `-` operator, on a bool tensor is not supported. "
- "If you are trying to invert a mask, use the `~` or `logical_not()` "
- "operator instead."
- ),
- )
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
- )
- def neg(a):
- return prims.neg(a)
- # positive does not use _make_elementwise_unary_reference because it does not support out
- # CompositeImplicitAutograd - don't register decomp
- def positive(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if a.dtype is torch.bool:
- msg = "positive does not support bool tensors."
- raise RuntimeError(msg)
- return a
- # real does not use _make_elementwise_unary_reference because it does not support out
- def real(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if utils.is_complex_dtype(a.dtype):
- return prims.real(a)
- return a
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def reciprocal(a):
- return prims.reciprocal(a)
- @register_decomposition(aten.round)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
- if decimals == 0:
- return prims.round(a)
- else:
- ten_pow = 10**decimals
- ten_neg_pow = 10 ** (-decimals)
- return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
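- # Editorial example: with nonzero `decimals`, rounding happens in the scaled
- # domain and the result is scaled back.
- # >>> torch.round(torch.tensor([1.2345]), decimals=2)
- # tensor([1.2300])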
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def rsqrt(a):
- return prims.rsqrt(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sigmoid(a: TensorLikeType) -> TensorLikeType:
- return true_divide(1, add(1, exp(neg(a))))
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def sgn(a):
- if utils.is_complex_dtype(a.dtype):
- a_abs = a.abs()
- return torch.where(a_abs == 0, 0, a / a_abs)
- else:
- return a.sign()
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def sign(a):
- return prims.sign(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def signbit(a):
- return prims.signbit(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sin(a):
- return prims.sin(a)
- # Autograd note: This will give the right first derivative at zero (by chance),
- # but not the right second derivative
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sinc(a):
- a = math.pi * a
- return torch.where(a == 0, 1, torch.sin(a) / a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sinh(a):
- return prims.sinh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sqrt(a):
- return prims.sqrt(a)
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
- aten_op=None, # CompositeImplicitAutograd,
- )
- def square(a: TensorLikeType) -> TensorLikeType:
- return mul(a, a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def tan(a):
- return prims.tan(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def tanh(a):
- return prims.tanh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def trunc(a):
- return prims.trunc(a)
- # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
- def view_as_complex(self: TensorLikeType) -> TensorLikeType:
- input_dtype = self.dtype
- torch._check(
- utils.is_float_dtype(input_dtype),
- lambda: f"view_as_complex is only supported for floating point"
- f"tensors, but got a tensor of scalar type: {input_dtype}",
- )
- sizes = self.size()
- torch._check(
- len(sizes) != 0,
- lambda: "Input tensor must have one or more dimensions",
- )
- torch._check(
- sizes[-1] == 2,
- lambda: "Tensor must have a last dimension of size 2",
- )
- old_strides = self.stride()
- torch._check(
- old_strides[-1] == 1,
- lambda: "Tensor must have a last dimension with stride 1",
- )
- dims = old_strides[:-1]
- torch._check(
- builtins.all(stride % 2 == 0 for stride in dims),
- lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
- )
- torch._check(
- self.storage_offset() % 2 == 0,
- lambda: "Tensor must have a storage_offset divisible by 2",
- )
- return prims.view_element_type(
- self, utils.corresponding_complex_dtype(input_dtype)
- ).squeeze(-1)
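- # Editorial example: the trailing dimension of size 2 is reinterpreted as
- # (real, imag) pairs, which is why the checks above require an even,
- # contiguously-strided last dimension.
- # >>> torch.view_as_complex(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
- # tensor([1.+2.j, 3.+4.j])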
- def _make_elementwise_binary_reference(
- type_promotion_kind,
- aten_op=infer_aten_op,
- name=None,
- has_out=True,
- supports_lhs_python_scalar=True,
- supports_rhs_python_scalar=True,
- supports_two_python_scalars=False,
- should_register_decomposition=True,
- ) -> Callable:
- def inner(prim: Callable):
- nonlocal aten_op, name
- if name is None:
- name = prim.__name__
- @wraps(prim)
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=type_promotion_kind,
- )
- def _ref(
- a: Union[Tensor, NumberType],
- b: Union[Tensor, NumberType],
- ) -> Tensor:
- torch._check_value(
- supports_lhs_python_scalar or not isinstance(a, Number),
- lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
- "operation that does not accept lhs scalars!",
- )
- torch._check_value(
- supports_rhs_python_scalar or not isinstance(b, Number),
- lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
- "operation that does not accept rhs scalars!",
- )
- torch._check_value(
- supports_two_python_scalars
- or not (isinstance(a, Number) and isinstance(b, Number)),
- lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
- )
- a, b = _maybe_broadcast(a, b)
- output = prim(a, b)
- return handle_noncontiguous_outputs([a, b], output)
- if has_out:
- _ref = out_wrapper()(_ref)
- _ref.__name__ = name
- if aten_op is infer_aten_op:
- aten_op = utils.get_aten_op(prim, name)
- if aten_op is not None and should_register_decomposition:
- register_decomposition(aten_op)(_ref)
- return _ref
- return inner
- # Add has its own implementation because it has an alpha argument
- @register_decomposition(aten.add)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def add(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- alpha: Optional[NumberType] = None,
- ):
- """
- Reference implementation of torch.add
- """
- a, b = _maybe_broadcast(a, b)
- if alpha is not None:
- dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
- python_type = utils.dtype_to_type(dtype)
- if python_type != bool and not utils.is_weakly_lesser_type(
- type(alpha), python_type
- ):
- msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
- raise ValueError(msg)
- if isinstance(b, TensorLike):
- b = prims.mul(b, alpha)
- else:
- b = b * alpha
- output = prims.add(a, b)
- return handle_noncontiguous_outputs([a, b], output)
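- # Editorial example: `alpha` scales the second operand before the addition,
- # i.e. add(a, b, alpha=s) computes a + s * b.
- # >>> torch.add(torch.tensor([1.0]), torch.tensor([10.0]), alpha=2)
- # tensor([21.])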
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def atan2(a, b):
- return prims.atan2(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_and(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.shift_left(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_or(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.shift_right_arithmetic(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_xor(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- )
- def copysign(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- if isinstance(b, Number) and isinstance(a, Tensor):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
- msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
- raise RuntimeError(msg)
- return where(signbit(b), neg(abs(a)), abs(a))
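- # Editorial example: the magnitude comes from `a` and the sign from `b`.
- # >>> torch.copysign(torch.tensor([1.0, -2.0]), torch.tensor([-1.0, 1.0]))
- # tensor([-1.,  2.])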
- # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- @register_decomposition(aten.div)
- @out_wrapper()
- def div(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- rounding_mode: Optional[str] = None,
- ):
- """
- Reference implementation of torch.div
- """
- if rounding_mode is None:
- return true_divide(a, b)
- elif rounding_mode == "trunc":
- return trunc_divide(a, b)
- elif rounding_mode == "floor":
- return floor_divide(a, b)
- else:
- msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
- raise ValueError(msg)
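- # Editorial example of the three rounding modes:
- # >>> a, b = torch.tensor(-7.0), torch.tensor(2.0)
- # >>> torch.div(a, b)
- # tensor(-3.5000)
- # >>> torch.div(a, b, rounding_mode="trunc")
- # tensor(-3.)
- # >>> torch.div(a, b, rounding_mode="floor")
- # tensor(-4.)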
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.eq(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
- )
- def pow(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> TensorLikeType:
- assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
- if isinstance(b, Number):
- if b == 1.0:
- return a.clone() # type: ignore[return-value,union-attr]
- elif b == 2.0:
- return a * a # type: ignore[return-value]
- elif b == 0.5:
- return torch.sqrt(a) # type: ignore[arg-type]
- elif isinstance(a, Number):
- if a == 1.0:
- return torch.fill(b, True)
- if a == 2.0 and (
- utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
- ):
- return torch.exp2(b)
- return prims.pow(a, b)
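- # Editorial note on the fast paths above: small constant exponents skip the
- # generic prims.pow, e.g. b == 2.0 lowers to a * a.
- # >>> torch.pow(torch.tensor([3.0]), 2.0)
- # tensor([9.])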
- # Float power has its own implementation because it has unique type promotion.
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def float_power(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> Tensor:
- if isinstance(a, Number) and isinstance(b, Number):
- raise ValueError(
- "Receive two Number inputs to an elementwise binary operation!"
- )
- # Handles type promotion
- dtype = utils.get_higher_dtype(a, b)
- assert dtype is not None
- if utils.is_complex_dtype(dtype):
- dtype = torch.complex128
- else:
- dtype = torch.float64
- # Float power has the following contiguous cast behavior to be
- # consistent with its C++ impl
- a = _maybe_convert_to_dtype(a, dtype)
- b = _maybe_convert_to_dtype(b, dtype)
- a, b = _maybe_broadcast(a, b)
- return pow(a, b)
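- # Editorial example: float_power always computes in double (or complex128)
- # precision regardless of the input dtypes.
- # >>> torch.float_power(torch.tensor([2.0]), 3).dtype
- # torch.float64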
- # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
- # tensor(-0.250000000000000, dtype=torch.float64)
- #
- # >>> b = torch.tensor(-0.0010, dtype=torch.float64)
- # tensor(-0.001000000000000, dtype=torch.float64)
- #
- # Note: In this case, casting float to double will expand the float mantissa with zeros,
- # while creating a double generates a distinct mantissa.
- # >>> torch.tensor(-0.001).to(dtype=torch.float64)
- # tensor(-0.001000000047497, dtype=torch.float64)
- #
- # Floor Division
- # The difference is caused because torch.remainder(a, b) = -0.001.
- #
- # >>> torch.floor(torch.true_divide(a, b))
- # tensor(250., dtype=torch.float64)
- #
- # >>> torch.div(a, b, rounding_mode='floor')
- # tensor(249., dtype=torch.float64)
- #
- # Definition: a // b = (a - remainder(a, b)) / b
- # >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
- # tensor(249., dtype=torch.float64)
- #
- # For reference, see CPython's implementation:
- # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
- @_make_elementwise_binary_reference(
- type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_two_python_scalars=True,
- should_register_decomposition=False,
- )
- def floor_divide(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- # Wrap scalars because some references only accept tensor arguments.
- if isinstance(a, Number) and isinstance(b, Number):
- a = scalar_tensor(a)
- b = scalar_tensor(b)
- elif isinstance(b, Number) and isinstance(a, Tensor):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(a, Number) and isinstance(b, Tensor):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
- if a.device == torch.device("cpu"):
- msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
- raise RuntimeError(msg)
- else:
- b = prims.device_put(b, device=a.device)
- assert isinstance(a, Tensor) and isinstance(b, Tensor)
- dtype = a.dtype
- if utils.is_float_dtype(dtype):
- return _floor_divide_float(a, b)
- elif utils.is_integer_dtype(dtype):
- return _floor_divide_integer(a, b)
- else:
- torch._check(False, lambda: f"{dtype} not supported for floor_divide")
- def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
- a, b = _maybe_broadcast(a, b)
- if not a.dtype.is_signed:
- return prims.div(a, b)
- # Convert truncation to flooring:
- offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
- return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
- def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
- mod = fmod(a, b)
- div = true_divide(sub(a, mod), b)
- # Ensure that the remainder has the same sign as the denominator
- different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
- non_zero_remainder = ne(mod, 0)
- mask = bitwise_and(non_zero_remainder, different_signed_inputs)
- div = where(mask, sub(div, 1), div)
- # Map quotient to nearest integer value
- floor_div = floor(div)
- mask = gt(sub(div, floor_div), 0.5)
- floor_div = where(mask, add(floor_div, 1), floor_div)
- basic_div = true_divide(a, b)
- zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)
- # If quotient is zero, copy signbit from true_divide quotient
- floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))
- # If denominator is zero, then follow true_divide behavior
- return where(ne(b, 0), floor_div, basic_div)
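- # Illustrative (hypothetical) check of the flooring behavior implemented above,
- # assuming standard eager-mode semantics:
- # >>> torch.floor_divide(torch.tensor([-7.0]), torch.tensor([2.0]))
- # tensor([-4.])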
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmax(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmin(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=True,
- )
- def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmod(a, b)
- @register_decomposition(aten.frexp)
- @out_wrapper("mantissa", "exponent")
- def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
- return torch.return_types.frexp(prims.frexp(self))
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.gcd(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.ge(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.gt(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
- input_eq_zero = torch.eq(input, 0)
- input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
- zeros_and_ones = torch.where(input_lt_zero, 0, 1)
- output = torch.where(input_eq_zero, values, zeros_and_ones)
- return output
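- # Illustrative (hypothetical) example, assuming standard eager-mode semantics:
- # values is used where input == 0, otherwise the result is 0 or 1:
- # >>> torch.heaviside(torch.tensor([-1.0, 0.0, 2.0]), torch.tensor(0.5))
- # tensor([0.0000, 0.5000, 1.0000])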
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.hypot(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.igamma(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.igammac(a, b)
- def _check_close_args(
- name: str,
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float,
- atol: float,
- ) -> None:
- torch._check_value(
- a.dtype == b.dtype,
- lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
- )
- torch._check(
- rtol >= 0,
- lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
- )
- torch._check(
- atol >= 0,
- lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
- )
- # CompositeImplicitAutograd - don't register decomp
- def isclose(
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float = 1e-05,
- atol: float = 1e-08,
- equal_nan: bool = False,
- ) -> TensorLikeType:
- _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)
- close = eq(a, b)
- if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
- close = logical_or(close, logical_and(isnan(a), isnan(b)))
- # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
- # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
- if atol == 0 and rtol == 0:
- return close
- # Note [closeness error computation]
- # atol and rtol are provided as doubles, so the computation
- # rtol * other will produce a float or complex tensor.
- # When the difference (self - other) is compared to it then the
- # tensor representing the difference will also be cast to float or complex.
- # However, since (self - other) in uint8 is very likely to produce a
- # negative value, this moves the cast forward so the difference is
- # always computed in a float or complex type.
- # If the values of the integer tensors cannot be exactly represented
- # by the default scalar type then this may cause an incorrect result.
- if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
- a = prims.convert_element_type(a, torch.get_default_dtype())
- b = prims.convert_element_type(b, torch.get_default_dtype())
- allowed_error = add(atol, abs(mul(b, rtol)))
- actual_error = abs(sub(a, b))
- # Computes finite closeness
- result = logical_or(
- close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
- )
- return result
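- # Illustrative (hypothetical) example of the closeness inequality above
- # (|a - b| <= atol + rtol * |b|), assuming standard eager-mode semantics:
- # >>> torch.isclose(torch.tensor([1.0, float("nan")]),
- # ...               torch.tensor([1.00001, float("nan")]), equal_nan=True)
- # tensor([True, True])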
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def lcm(a: TensorLikeType, b: TensorLikeType):
- dtype = a.dtype
- # promoting to int32 to maintain 100% consistency with C++ and to
- # prevent overflow in case of int8 and int16
- promote_to_int = dtype in (torch.int8, torch.int16)
- if promote_to_int:
- a = prims.convert_element_type(a, torch.int32)
- b = prims.convert_element_type(b, torch.int32)
- g = torch.gcd(a, b)
- # Avoid division by zero in case gcd(0, 0) == 0
- g = torch.where(g == 0, 1, g)
- res = torch.abs(prims.div(a, g) * b)
- return res if not promote_to_int else prims.convert_element_type(res, dtype)
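- # Illustrative (hypothetical) example, including the gcd(0, 0) == 0 guard above:
- # >>> torch.lcm(torch.tensor([4, 0]), torch.tensor([6, 0]))
- # tensor([12, 0])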
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.le(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- # Nb. this implementation does not distribute the gradients evenly when a == b
- mask = torch.real(a) >= torch.real(b)
- max_ = torch.where(mask, a, b)
- min_ = torch.where(mask, b, a)
- inf_mask = torch.logical_and(
- torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
- )
- if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype):
- # The following handles edge cases (infinities and NaNs) for complex inputs.
- neg_min_mask = torch.real(min_) < 0
- inf_vals = torch.where(
- neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_))
- )
- non_nan_vals = torch.where(
- inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_))
- )
- # Note: the type annotation for full_like does not include tensor values yet
- nan_mask = torch.isnan(min_)
- return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload]
- else:
- return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- torch._check(
- not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)),
- lambda: "logaddexp2 doesn't support complex dtypes",
- )
- # Nb. this implementation does not distribute the gradients evenly when a == b
- mask = a >= b
- max_ = torch.where(mask, a, b)
- min_ = torch.where(mask, b, a)
- inf_mask = torch.logical_and(torch.isinf(a), a == b)
- inv_log_2 = 1.0 / math.log(2)
- result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2
- return torch.where(inf_mask, a, result)
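- # Illustrative (hypothetical) example: logaddexp2(a, b) = log2(2**a + 2**b),
- # assuming standard eager-mode semantics:
- # >>> torch.logaddexp2(torch.tensor([1.0]), torch.tensor([1.0]))
- # tensor([2.])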
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_and(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return a & b
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def logical_not(a: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- return a == 0
- return ~a
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_or(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return bitwise_or(a, b)
- # TODO: skip unnecessary conversion of long to float
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_xor(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return a ^ b
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.lt(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.maximum(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.minimum(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_two_python_scalars=True,
- )
- def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.mul(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.ne(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.nextafter(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.remainder(a, b)
- # reverse sub
- @register_decomposition(aten.rsub)
- @out_wrapper()
- def rsub(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- alpha: NumberType = 1,
- ):
- if isinstance(a, Number):
- msg = "Received a Number for the first argument, but expected a Tensor"
- raise ValueError(msg)
- return torch.sub(b, a, alpha=alpha)
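- # Illustrative (hypothetical) example: rsub(a, b, alpha) computes b - alpha * a:
- # >>> torch.rsub(torch.tensor([1.0, 2.0]), 5, alpha=2)
- # tensor([3., 1.])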
- # TODO: consider refactoring this with add impl
- # sub has its own implementation because it has an alpha argument
- @register_decomposition(aten.sub)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def sub(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- alpha: NumberType = 1,
- ):
- """
- Reference implementation of torch.sub
- """
- a, b = _maybe_broadcast(a, b)
- if alpha != 1:
- dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
- python_type = utils.dtype_to_type(dtype)
- if not utils.is_weakly_lesser_type(type(alpha), python_type):
- msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
- raise ValueError(msg)
- if isinstance(b, torch.Tensor):
- b = prims.mul(b, alpha)
- else:
- # Take care not to use prims.mul if b is a scalar / symint:
- # prims.mul always returns a tensor,
- # which would interfere with type promotion.
- b = b * alpha
- output = prims.sub(a, b)
- return handle_noncontiguous_outputs([a, b], output)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- name="true_divide",
- aten_op=None, # CompositeImplicitAutograd
- supports_two_python_scalars=True,
- )
- def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.div(a, b)
- @register_decomposition(aten.xlogy)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
- torch._check(
- isinstance(a, TensorLike) or isinstance(b, TensorLike),
- lambda: "Expected either argument a or b to be a Tensor",
- )
- # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
- if isinstance(b, TensorLike) and isinstance(a, Number):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- elif isinstance(a, TensorLike) and isinstance(b, Number):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- # mypy: expected "Tensor"
- assert isinstance(a, TensorLike)
- assert isinstance(b, TensorLike)
- rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
- return torch.where(torch.isnan(b), float("nan"), rhs)
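- # Illustrative (hypothetical) example: xlogy(0, b) is 0 even when log(b)
- # would be -inf, per the where(eq(a, 0), ...) above:
- # >>> torch.xlogy(torch.tensor([0.0, 2.0]), torch.tensor([0.0, 3.0]))
- # tensor([0.0000, 2.1972])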
- @_make_elementwise_binary_reference(
- type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- aten_op=None, # CompositeImplicitAutograd
- supports_two_python_scalars=True,
- )
- def trunc_divide(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- dtype = utils.get_dtype(a)
- if utils.is_integer_dtype(dtype):
- return prims.div(a, b)
- return trunc(prims.div(a, b))
- #
- # Elementwise Ternary References
- #
- @register_decomposition(aten.addcdiv)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "tensor1", "tensor2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def addcdiv(
- self: TensorLikeType,
- tensor1: TensorLikeType,
- tensor2: TensorLikeType,
- *,
- value: NumberType = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.addcdiv
- """
- if value is not None:
- dtype = self.dtype # no scalars allowed, see add
- python_type = utils.dtype_to_type(dtype)
- torch._check_value(
- utils.is_weakly_lesser_type(type(value), python_type),
- lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
- )
- return self + value * tensor1 / tensor2
- @register_decomposition(aten.addcmul)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "tensor1", "tensor2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def addcmul(
- self: TensorLikeType,
- tensor1: TensorLikeType,
- tensor2: TensorLikeType,
- *,
- value: NumberType = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.addcmul
- """
- if value is not None:
- dtype = self.dtype # no scalars allowed, see add
- python_type = utils.dtype_to_type(dtype)
- torch._check_value(
- utils.is_weakly_lesser_type(type(value), python_type),
- lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
- )
- return self + value * tensor1 * tensor2
- @register_decomposition(aten.clamp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "min", "max"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def clamp(
- a: TensorLikeType,
- min: Optional[TensorOrNumberLikeType] = None,
- max: Optional[TensorOrNumberLikeType] = None,
- ) -> TensorLikeType:
- # NOTE: the gradient behavior of this `where`-based implementation is not consistent for `nan` inputs
- if min is None and max is None:
- msg = "clamp called but both min and max are none!"
- raise ValueError(msg)
- if min is not None:
- a_isnan = torch.isnan(a)
- condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type]
- # We should also propagate `nan`s coming from the boundaries. However, that's
- # not necessary since `ge` already returns `False` when either operand is
- # `nan`, so the line below would be redundant:
- # `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
- a = torch.where(condition, a, min) # type: ignore[arg-type]
- if max is not None:
- a_isnan = torch.isnan(a)
- # same as above, no need to adjust `nan` from `max`
- condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type]
- a = torch.where(condition, a, max) # type: ignore[arg-type]
- return a
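- # Illustrative (hypothetical) example, assuming standard eager-mode semantics:
- # >>> torch.clamp(torch.tensor([-2.0, 0.0, 3.0]), min=-1.0, max=1.0)
- # tensor([-1., 0., 1.])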
- @register_decomposition(aten.clamp_min)
- @out_wrapper()
- def clamp_min(
- self: TensorLikeType,
- min: Optional[TensorOrNumberLikeType] = None,
- ) -> TensorLikeType:
- return torch.clamp(self, min=min) # type: ignore[arg-type]
- @register_decomposition(aten.clamp_max)
- @out_wrapper()
- def clamp_max(
- self: TensorLikeType,
- max: Optional[TensorOrNumberLikeType] = None,
- ) -> TensorLikeType:
- return torch.clamp(self, max=max) # type: ignore[arg-type]
- #
- # Conditional references
- #
- # https://pytorch.org/docs/stable/generated/torch.where.html
- # TODO: implement alternate where
- @register_decomposition(aten.where)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def where(
- pred: Tensor,
- a: Optional[TensorOrNumberLikeType] = None,
- b: Optional[TensorOrNumberLikeType] = None,
- ):
- """ """
- if a is None or b is None:
- raise NotImplementedError
- utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
- torch._check(
- pred.dtype is torch.bool,
- lambda: f"expected predicate to be bool, got {pred.dtype}",
- )
- pred, a, b = _maybe_broadcast(pred, a, b)
- return prims.where(pred, a, b)
- #
- # Data Movement References
- #
- @register_decomposition(aten.clone)
- @out_wrapper()
- def clone(
- a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
- ) -> TensorLikeType:
- result = prims.clone(a, memory_format=memory_format)
- return result
- def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
- if not allow_cross_device and a.device != b.device:
- msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!"
- raise RuntimeError(msg)
- return prims.copy_to(a, b)
- @register_decomposition(aten.item)
- def item(a: TensorLikeType) -> NumberType:
- if a.numel() != 1:
- msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
- raise ValueError(msg)
- # NOTE: explicit conversion is necessary for bool!
- # See https://github.com/pytorch/pytorch/issues/78071
- number_type = utils.dtype_to_type(a.dtype)
- return number_type(prims.item(a))
- # Fast path for when `to` returns an alias of its input. This mimics the corresponding function in ATen.
- def _to_will_alias(
- a: TensorLikeType,
- device: Optional[DeviceLikeType] = None,
- dtype: Optional[torch.dtype] = None,
- copy: Optional[bool] = None,
- layout: Optional[torch.layout] = None,
- memory_format: Optional[torch.memory_format] = None,
- pin_memory: Optional[bool] = False,
- non_blocking: bool = False, # not using non_blocking
- ) -> bool:
- return (
- not copy
- and (device is None or a.device == device)
- and (dtype is None or a.dtype == dtype)
- and (layout is None or a.layout == layout)
- # is_pinned issue #84925
- # and (pin_memory is None or pin_memory == a.is_pinned())
- and (
- memory_format is None
- or memory_format == torch.preserve_format
- or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
- )
- )
- @singledispatch
- def _to_dispatch(*args, **kwargs):
- raise NotImplementedError
- @_to_dispatch.register
- def _to_device(
- device: torch.device,
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ) -> Dict[str, Any]:
- kwargs = {
- "device": device,
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_device_str(
- device: str,
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ) -> Dict[str, Any]:
- kwargs = {
- "device": torch.device(device),
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_dtype(
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ) -> Dict[str, Any]:
- kwargs = {
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_other(
- other: Tensor,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ) -> Dict[str, Any]:
- device = other.device
- dtype = other.dtype
- layout = other.layout
- # is_pinned issue #84925
- # pin_memory = other.is_pinned()
- kwargs = {
- "device": device,
- "dtype": dtype,
- "layout": layout,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- # remove to_kwargs that is already present in `a`
- def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
- options_to_check = ["dtype", "device", "layout", "memory_format"]
- # "device" option could be passed a str instead torch.device
- if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
- to_kwargs["device"] = torch.device(to_kwargs["device"])
- for kw in options_to_check:
- if kw in to_kwargs:
- if (
- (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
- or (
- kw == "device"
- and to_kwargs[kw].type == a.device.type
- and (
- not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
- )
- )
- or (
- getattr(a, kw, None) == to_kwargs[kw]
- ) # this also handles {"memory_format": None}
- ):
- to_kwargs.pop(kw)
- def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
- # handle dispatch via positional arguments
- if len(args) != 0:
- kwargs = _to_dispatch(*args, **kwargs)
- # TODO: is_pinned is not currently supported in refs or fake_tensor
- # https://github.com/pytorch/pytorch/issues/84925
- assert "pin_memory" not in kwargs
- _canonicalize_to_arguments(a, kwargs)
- if _to_will_alias(a, **kwargs):
- return a
- copy = kwargs.pop("copy") if "copy" in kwargs else False
- non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False
- # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
- if (
- (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
- and (not non_blocking)
- and ("memory_format" not in kwargs)
- and ("device" not in kwargs)
- and ("layout" not in kwargs)
- # is_pinned issue #84925
- # and ("pin_memory" not in kwargs)
- ):
- return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))
- result = torch.empty_like(a, **kwargs)
- # TODO: non_blocking should be handled by `copy_to`
- copy_to(result, a)
- return result
- #
- # Reduction references
- #
- def _reduction(
- a: TensorLikeType,
- prim: Callable,
- *,
- has_identity: bool = True,
- accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only
- dims: Optional[DimsType] = None,
- keepdims: bool = False,
- dtype: Optional[torch.dtype] = None, # should be specified for ops that support it
- out: Optional[Tensor] = None,
- output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
- ) -> TensorLikeType:
- # Note: output_dtype_kind is usually SAME, but it has no default so that
- # ref writers actually think about what to put here.
- assert isinstance(a, TensorLike)
- if a.ndim > 64:
- raise RuntimeError(
- f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!"
- )
- if out is not None:
- assert isinstance(out, TensorLike)
- if dtype is not None:
- # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
- if dtype != out.dtype:
- raise RuntimeError(
- "dtype argument and out dtype must match in reduction"
- )
- if not accepts_dim_tuple:
- assert dims is None or isinstance(dims, Dim)
- if isinstance(dims, Dim):
- dims = (dims,) # type: ignore[assignment]
- dims = utils.reduction_dims(a.shape, dims)
- if not has_identity:
- valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
- if not valid_shape:
- raise RuntimeError(
- "reducing over zero-size dimension for reduction operation without identity"
- )
- computation_dtype, result_dtype = utils.reduction_dtypes(
- a, output_dtype_kind, dtype
- )
- a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign]
- result = prim(a, dims)
- if keepdims:
- output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
- broadcast_dims = [i for i in range(a.ndim) if i not in dims]
- result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
- if out is not None:
- assert result_dtype is not None
- if dtype is not None and result_dtype != out.dtype:
- raise RuntimeError(
- "Expected the dtype of reduction result and out to match"
- )
- out = _maybe_resize_out(out, result.shape)
- return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
- if result.dtype != result_dtype and result_dtype is not None:
- result = prims.convert_element_type(result, result_dtype)
- return result
- def _make_copy_from_view(fn):
- """
- Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
- """
- name = fn.__name__
- fn = out_wrapper()(fn)
- def _fn(*args, out=None, **kwargs):
- result = fn(*args, out=out, **kwargs)
- if out is None:
- return result.clone(memory_format=torch.contiguous_format)
- return result
- copy_name = f"{name}_copy"
- _fn.__name__ = copy_name
- _fn = register_decomposition(getattr(aten, copy_name))(_fn)
- return _fn
- # Saves Python all
- py_all = all
- @register_decomposition(aten.all)
- @out_wrapper()
- def all(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- ) -> TensorLikeType:
- result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
- if a.dtype == torch.uint8:
- result = result.to(dtype=torch.uint8)
- return result
- # Saves Python any
- py_any = any
- @register_decomposition(aten.any)
- @out_wrapper()
- def any(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- ) -> TensorLikeType:
- a_ = _maybe_convert_to_dtype(a, torch.bool)
- if isinstance(dim, (list, tuple)) and len(dim) == 0:
- result = a_.clone()
- else:
- result = a_.sum(dim=dim, keepdim=keepdim).ne(False)
- # Preserves uint8 -- probably a legacy mask thing
- if a.dtype is torch.uint8:
- return prims.convert_element_type(result, torch.uint8)
- return result
- @register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
- def sum(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- keepdim: bool = False,
- *,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- if dtype is None:
- if out is not None:
- dtype = out.dtype
- elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- dtype = torch.int64
- else:
- dtype = a.dtype
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.sum,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=out,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- def sum_to_size(
- a: Tensor,
- *shape,
- ) -> Tensor:
- shape = utils.extract_shape_from_varargs(shape, validate=False)
- torch._check(
- utils.is_expandable_to(shape, a.shape),
- lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
- )
- # In ATen, scalar tensors are sent through sum and the result is returned
- # type-promoted.
- if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
- return prims.view_of(a)
- leading_dims = a.ndim - len(shape)
- reduce_dims = tuple(range(leading_dims)) + tuple(
- i
- for i in range(leading_dims, len(shape))
- if shape[i - leading_dims] == 1 and a.shape[i] != 1
- )
- return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
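- # Illustrative (hypothetical) example: reduce a (2, 3) tensor to size (1, 3):
- # >>> torch.ones(2, 3).sum_to_size(1, 3)
- # tensor([[2., 2., 2.]])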
- @register_decomposition(aten.prod)
- def prod(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- keepdim: bool = False,
- *,
- dtype=None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- if dtype is None:
- if out is not None:
- dtype = out.dtype
- elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- dtype = torch.int64
- else:
- dtype = a.dtype
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.prod,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=out,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- @register_decomposition(aten.amin)
- def amin(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.amin,
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=out,
- has_identity=False,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- @register_decomposition(aten.amax)
- def amax(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.amax,
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=out,
- has_identity=False,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- def _dim_var_dispatch(dim=None, unbiased=None):
- # There's the following overload of torch.var:
- # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
- # We need to explicitly convert bool dims to unbiased arg
- if unbiased is None and isinstance(dim, bool):
- unbiased = dim
- dim = None
- return dim, unbiased
- @register_decomposition(aten.var)
- @out_wrapper()
- def var(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[NumberType] = None,
- ) -> TensorLikeType:
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- result = _reduction(
- a,
- partial(prims.var, correction=correction),
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=None,
- has_identity=True,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
- )
- return result
- @register_decomposition(aten.std)
- @out_wrapper()
- def std(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[NumberType] = None,
- ) -> TensorLikeType:
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- opmath_dtype, dtype = utils.reduction_dtypes(
- a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- )
- a = _maybe_convert_to_dtype(a, opmath_dtype)
- a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
- a_std = torch.sqrt(a_var)
- assert dtype is not None
- return _maybe_convert_to_dtype(a_std, dtype)
- @register_decomposition(aten.mean)
- def mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- dtype=None,
- out=None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- orig_dtype = dtype
- if dtype is None:
- dtype = a.dtype
- # can't use the out wrapper because out interacts with the dtype argument here
- torch._check(
- out is None or out.dtype == dtype,
- lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
- )
- result = _reduction(
- a,
- prims.sum,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=None,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
- )
- torch._check(
- utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
- lambda: (
- f"mean(): could not infer output dtype. "
- f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
- f"a floating point or complex dtype. Got: {dtype}"
- ),
- )
- if isinstance(dim, Dim):
- dim = (dim,) # type: ignore[assignment]
- dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type]
- nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
- result = true_divide(result, nelem)
- result_dtype = a.dtype if dtype is None else dtype
- result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign]
- if out is not None:
- assert isinstance(out, TensorLike)
- out = _maybe_resize_out(out, result.shape)
- return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
- return result
- @register_decomposition(aten.std_mean)
- @out_wrapper("out0", "out1")
- def std_mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- *,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- correction: Optional[NumberType] = None,
- ):
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- opmath_dtype, dtype = utils.reduction_dtypes(
- a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- )
- original_dtype = a.dtype
- a = _maybe_convert_to_dtype(a, opmath_dtype)
- a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
- a_std = torch.sqrt(a_var)
- assert dtype is not None
- return (
- _maybe_convert_to_dtype(a_std, dtype),
- _maybe_convert_to_dtype(a_mean, original_dtype),
- )
- @register_decomposition(aten.var_mean)
- @out_wrapper("out0", "out1")
- def var_mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[NumberType] = None,
- ):
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- v = var(a, dim, unbiased, keepdim, correction=correction)
- m = mean(a, dim, keepdim)
- return v, m
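- # Illustrative (hypothetical) example (default correction=1, i.e. unbiased),
- # assuming standard eager-mode semantics:
- # >>> torch.var_mean(torch.tensor([1.0, 2.0, 3.0, 4.0]))
- # (tensor(1.6667), tensor(2.5000))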
- @register_decomposition(aten.addr)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "vec1", "vec2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def addr(
- self: TensorLikeType,
- vec1: TensorLikeType,
- vec2: TensorLikeType,
- *,
- beta: NumberType = 1,
- alpha: NumberType = 1,
- ) -> TensorLikeType:
- torch._check(
- vec1.ndim == 1,
- lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
- )
- torch._check(
- vec2.ndim == 1,
- lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
- )
- self = self.expand(vec1.shape[0], vec2.shape[0])
- if utils.is_boolean_dtype(self.dtype):
- # Integers are accepted for booleans
- torch._check(
- is_weakly_lesser_type(type(beta), int),
- lambda: f"expected bool/int beta but got {type(beta)}",
- )
- torch._check(
- is_weakly_lesser_type(type(alpha), int),
- lambda: f"expected bool/int alpha but got {type(beta)}",
- )
- if not beta:
- return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
- else:
- return torch.logical_or(
- self,
- torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
- )
- else:
- torch._check(
- is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
- lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
- )
- torch._check(
- is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
- lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
- )
- if beta == 0:
- # This means NaNs from self are dropped if beta is zero
- return alpha * torch.outer(vec1, vec2)
- else:
- return beta * self + alpha * torch.outer(vec1, vec2)
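- # Illustrative (hypothetical) example: beta * self + alpha * outer(vec1, vec2):
- # >>> torch.addr(torch.zeros(2, 2), torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]))
- # tensor([[3., 4.],
- #         [6., 8.]])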
- # CompositeImplicitAutograd - don't register decomp
- def atleast_1d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_1d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
- return res if len(res) > 1 else res[0]
- # Helper function with assert to avoid MyPy error
- # of incompatible type passed to unsqueeze
- def _unsqueeze_atleast(
- at_least_fn: Callable, dim: int, arg: TensorLikeType
- ) -> TensorLikeType:
- arg_ = at_least_fn(arg)
- assert isinstance(arg_, TensorLike)
- return unsqueeze(arg_, dim)
- # CompositeImplicitAutograd - don't register decomp
- def atleast_2d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_2d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
- res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
- return res if len(res) > 1 else res[0]
- # CompositeImplicitAutograd - don't register decomp
- def atleast_3d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_3d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
- res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
- return res if len(res) > 1 else res[0]
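- # Illustrative (hypothetical) examples, assuming standard eager-mode semantics:
- # >>> torch.atleast_2d(torch.tensor(1.0)).shape
- # torch.Size([1, 1])
- # >>> torch.atleast_3d(torch.ones(2, 3)).shape
- # torch.Size([2, 3, 1])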
- def as_strided(
- a: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- storage_offset: Optional[int] = None,
- ) -> TensorLikeType:
- storage_offset_int = (
- storage_offset if storage_offset is not None else a.storage_offset()
- )
- return prims.as_strided(a, size, stride, storage_offset_int)
- @register_decomposition(aten.as_strided_scatter)
- @out_wrapper()
- def as_strided_scatter(
- input: TensorLikeType,
- src: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- storage_offset: Optional[int] = None,
- ) -> TensorLikeType:
- storage_offset_int = 0 if storage_offset is None else storage_offset
- return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
- def broadcast_shapes(*shapes) -> ShapeType:
- return torch.Size(_broadcast_shapes(*shapes))
- @aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
- def broadcast_tensors(*tensors) -> List[TensorLikeType]:
- if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
- tensors = tensors[0]
- return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
- # CompositeImplicitAutograd - don't register decomp
- def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
- start = len(size) - len(a.shape)
- dims = tuple(range(start, len(a.shape) + start))
- return prims.broadcast_in_dim(a, size, dims)
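- # Illustrative (hypothetical) example, assuming standard eager-mode semantics:
- # >>> torch.broadcast_to(torch.tensor([1, 2, 3]), (2, 3))
- # tensor([[1, 2, 3],
- #         [1, 2, 3]])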
- @register_decomposition(aten.cat)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("tensors",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
- def cat_compute_output_memory_format(inputs):
- format = None
- for t in inputs:
- f = utils.suggest_memory_format(t)
- if f == torch.contiguous_format:
- return f
- if format is not None and format != f:
- return torch.contiguous_format
- format = f
- assert format is not None
- return format
- if len(tensors) == 0:
- msg = "cat expects at least one tensor, but received zero!"
- raise ValueError(msg)
- for tensor in tensors:
- assert isinstance(tensor, TensorLike)
- utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- # This is a bit tricky. Naively, you would expect to just pick one
- # arbitrary tensor and check that all tensors match this tensor. However,
- # there is legacy behavior which says that if you have a 1-D empty tensor
- # (0,), this is permissible. So you can't assume that all the tensors
- # have same dimensionality, and you can't assume that the first tensor is
- # the correct stencil.
- #
- # We'll implement this in a few passes. First, we will try to infer the
- # ndim of the cat output. If this ndim != 1, then we know that all ndim =
- # 1 inputs must be empty, or are errors. If this ndim == 1, then life
- # is easy (the legacy special case coincides with regular handling).
- #
- # NB: The regular implementation of cat just filters out empty inputs,
- # but we do it slightly different here for better handling for unbacked
- # SymInts
- example = None
- for i, t in enumerate(tensors):
- if example is None:
- if t.ndim != 1:
- example = t
- else:
- if t.ndim != 1:
- torch._check(
- t.ndim == example.ndim,
- lambda: "Number of dimensions of tensors must match. "
- f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for "
- f"tensor number {i} in the list",
- )
- if example is None:
- # example is None if everything is 1-D. If so, just arbitrarily pick
- # the first one
- example = tensors[0]
- shape = example.shape
- filtered = []
- for tensor_idx, tensor in enumerate(tensors):
- if len(shape) != len(tensor.shape):
- assert tensor.ndim == 1 # we've already checked this above
- # Don't suggest the legacy behavior in the error message
- torch._check(
- tensor.shape[0] == 0,
- lambda: f"Number of dimensions of tensors must match. "
- f"Expected {example.ndim}-D tensors, but got 1-D for "
- f"tensor number {tensor_idx} in the list",
- )
- else:
- # Remove inputs that are 1-D, zero size
- if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0):
- continue
- # Don't bother checking size match, prims.cat will handle it
- filtered.append(tensor)
- memory_format = cat_compute_output_memory_format(tensors)
- if len(filtered) == 0:
- t = tensors[0]
- # TODO: fix this to work with meta tensors
- try:
- requires_grad = any(x.requires_grad for x in tensors)
- except Exception:
- requires_grad = False
- return empty(
- (0,),
- dtype=t.dtype,
- device=t.device,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- dim = utils.canonicalize_dim(filtered[0].ndim, dim)
- utils.validate_idx(filtered[0].ndim, dim)
- return prims.cat(filtered, dim).clone(memory_format=memory_format)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
- aligned_tensors = tuple(
- x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
- )
- return cat(aligned_tensors, 1)
- def conj(input: TensorLikeType) -> TensorLikeType:
- if not utils.is_complex_dtype(input.dtype):
- return input
- if input.is_sparse:
- return torch.conj_physical(input)
- return prims.conj(input)
- # This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
- @register_decomposition(aten.constant_pad_nd)
- @out_wrapper()
- def constant_pad_nd(
- input: TensorLikeType, pad: List[int], value: NumberType = 0
- ) -> TensorLikeType:
- torch._check(
- len(pad) % 2 == 0,
- lambda: f"Length of pad must be even but instead it equals {len(pad)}",
- )
- input_sizes = input.shape
- l_inp = len(input_sizes)
- l_pad = len(pad) // 2
- l_diff = l_inp - l_pad
- torch._check(
- l_inp >= l_pad,
- lambda: "Length of pad should be no more than twice the number of "
- f"dimensions of the input. Pad length is {len(pad)} while the input has "
- f"{l_inp} dimensions.",
- )
- c_input = input
- for i in range(l_diff, l_inp):
- pad_idx = 2 * (l_inp - i - 1)
- if pad[pad_idx] < 0:
- c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])
- if pad[pad_idx + 1] < 0:
- c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
- # if none of the pads are positive we can just return the result
- if builtins.all(p <= 0 for p in pad):
- return c_input.clone()
- new_shape = list(input_sizes[:l_diff])
- for i in range(l_pad):
- pad_idx = len(pad) - ((i + 1) * 2)
- new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
- torch._check(
- new_dim > 0,
- lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
- f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
- f"which is invalid. Check dimension {l_diff + i} of your input.",
- )
- new_shape.append(new_dim)
- memory_format = utils.suggest_memory_format(input)
- output = torch.empty(
- new_shape,
- dtype=input.dtype,
- device=input.device,
- requires_grad=input.requires_grad,
- memory_format=memory_format,
- )
- if value == 0 and input.dtype == torch.bool:
- value = False
- # torch.fill isn't typed to allow complex values
- output = torch.fill(output, value) # type: ignore[arg-type]
- c_output = output
- for i in range(l_diff, l_inp):
- pad_idx = 2 * (l_inp - i - 1)
- if pad[pad_idx] > 0:
- c_output = c_output.narrow(
- i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
- )
- if pad[pad_idx + 1] > 0:
- c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])
- prims.copy_to(c_output, c_input)
- return output
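- # Illustrative (hypothetical) example: pad the last dimension by one element
- # on each side, assuming standard eager-mode semantics:
- # >>> torch.constant_pad_nd(torch.ones(2, 2), [1, 1], 9)
- # tensor([[9., 1., 1., 9.],
- #         [9., 1., 1., 9.]])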
- def contiguous(
- a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
- ) -> Tensor:
- torch._check(
- memory_format != torch.preserve_format,
- lambda: "preserve memory format is unsupported by the contiguous operator",
- )
- if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
- return a
- return torch.clone(a, memory_format=memory_format)
- @out_wrapper()
- def dstack(tensors: TensorSequenceType) -> TensorLikeType:
- torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
- aligned_tensors = atleast_3d(*tensors)
- return cat(aligned_tensors, 2)
- @register_decomposition(aten.expand)
- def expand(a: Tensor, *shape) -> Tensor:
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- # NOTE: cannot use utils.extract_shape_from_varargs here
- # because that also validates the shape, but the shape
- # given to expand may be "invalid"
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
- torch._check(
- len(shape) >= len(a.shape),
- lambda: "expand: the requested shape has too few dimensions!",
- )
- offset = len(shape) - len(a.shape)
- shape_ = list(shape)
- for idx, x in enumerate(a.shape):
- offset_idx = idx + offset
- requested_length = shape[offset_idx]
- torch._check(
- guard_size_oblivious(requested_length == x)
- or guard_size_oblivious(x == 1)
- or requested_length == -1,
- lambda: f"expand: attempting to expand a dimension of length {x}!",
- )
- shape_[offset_idx] = requested_length if requested_length != -1 else x
- # At this point shape must be valid
- utils.validate_shape(shape_)
- return prims.broadcast_in_dim(
- a, shape_, tuple(range(offset, len(a.shape) + offset))
- )
- # CompositeImplicitAutograd - don't register decomp
- def expand_as(a: Tensor, b: Tensor) -> Tensor:
- return a.expand(b.shape)
- def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
- if chunks <= 0:
- msg = f"Expected at least one chunk, but got {chunks}!"
- raise ValueError(msg)
- dim = utils.canonicalize_dim(a.ndim, dim)
- length = a.shape[dim]
- chunk_size = math.ceil(length / chunks)
- full_chunks = math.floor(length / chunk_size)
- tail_chunk_size = length % chunk_size
- result = []
- for i in range(full_chunks):
- result.append(narrow(a, dim, i * chunk_size, chunk_size))
- if tail_chunk_size != 0:
- result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))
- return tuple(result)
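- # Illustrative (hypothetical) example: chunk sizes come from ceil(length / chunks),
- # so the last chunk may be smaller:
- # >>> torch.chunk(torch.arange(5), 2)
- # (tensor([0, 1, 2]), tensor([3, 4]))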
- # Note: unlike other shape operators, flatten returns the input tensor on a no-op
- # (unless a 0-D tensor is flattened, in which case a 1-D tensor is returned).
- # CompositeImplicitAutograd - don't register decomp
- def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
- start_dim = utils.canonicalize_dim(a.ndim, start_dim)
- end_dim = utils.canonicalize_dim(a.ndim, end_dim)
- # Short-circuits on no-op
- if start_dim == end_dim and a.ndim != 0:
- return a
- # Tries to take a view
- # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
- new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
- if new_shape is not None:
- return prims.collapse_view(a, start_dim, end_dim)
- # Makes a copy if it can't make a view
- return prims.collapse(a, start_dim, end_dim)
- @register_decomposition(aten.flip)
- @out_wrapper()
- def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
- if not isinstance(dims, tuple) and not isinstance(dims, list):
- raise ValueError("dims has to be a sequence of ints")
- dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment]
- utils.validate_no_repeating_dims(dims)
- return prims.rev(a, dims)
- # CompositeImplicitAutograd - don't register decomp
- def fliplr(a: TensorLikeType) -> TensorLikeType:
- if a.ndim < 2:
- raise RuntimeError("Input must be >= 2-d.")
- return flip(a, (1,))
- # CompositeImplicitAutograd - don't register decomp
- def flipud(a: TensorLikeType) -> TensorLikeType:
- if a.ndim < 1:
- raise RuntimeError("Input must be >= 1-d.")
- return flip(a, (0,))
- # CompositeImplicitAutograd - don't register decomp
- def narrow(
- a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
- ) -> TensorLikeType:
- # Supports Tensor overload that was added for XLA:
- # https://github.com/pytorch/pytorch/issues/31558
- if isinstance(start, TensorLike):
- torch._check(
- start.dim() == 0 and utils.is_integer_dtype(start.dtype),
- lambda: "start must be an 0-dim integral Tensor.",
- )
- start = start.item() # type: ignore[assignment]
- torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
- torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
- dim = utils.canonicalize_dim(a.ndim, dim)
- dim_length = a.size(dim)
- torch._check_with(
- IndexError,
- -dim_length <= start and start <= dim_length, # type: ignore[arg-type]
- lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
- )
- if start < 0:
- start = start + dim_length
- torch._check(
- start <= dim_length - length, # type: ignore[arg-type]
- lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
- )
- return prims.slice_in_dim(a, start, start + length, axis=dim)
- # TODO: This must return a sparse tensor if the input is sparse, but refs have
- # no sparse support. See narrow_copy_sparse in core.
- narrow_copy = _make_copy_from_view(narrow)
- def _normalize(
- a: Tensor, norm_dims: DimsType, eps: float
- ) -> Tuple[Tensor, Tensor, Tensor]:
- """Computes mean and 1/std of a tensor along norm_dims.
- Used as a helper function for normalization layers.
- Args:
- a (Tensor): input tensor
- norm_dims (DimsType): dimensions to normalize over
- eps (float): epsilon for numerical stability
- Returns:
- out (Tensor): normalized tensor.
- mean (Tensor): mean of the tensor along norm_dims.
- rstd (Tensor): 1/std of the tensor along norm_dims.
- """
- norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
- computation_dtype = utils.get_computation_dtype(a.dtype)
- a_acc = _maybe_convert_to_dtype(a, computation_dtype)
- assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean
- biased_var, mean = torch.var_mean(
- a_acc, dim=norm_dims, unbiased=False, keepdim=True
- )
- rstd = torch.rsqrt(biased_var + eps)
- out = (a - mean) * rstd
- return out, mean, rstd
- # add all specified dimensions
- def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
- for dim in sorted(dimensions):
- x = torch.unsqueeze(x, dim)
- return x
- @register_decomposition(aten.native_group_norm.default)
- def native_group_norm(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- batch_size: int,
- num_channels: int,
- flattened_inner_size: int,
- num_groups: int,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- torch._check(
- input.ndim >= 2,
- lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
- )
- torch._check(
- num_channels % num_groups == 0,
- lambda: "Expected number of channels in input to be divisible by num_groups, "
- + f"but got input of shape {input.shape} and num_groups = {num_groups}",
- )
- # num_channels / num_groups and flattened inner dimension are the reduction axes
- reduction_dims = [2, 3]
- input_reshaped = torch.reshape(
- input,
- [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
- )
- out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
- out = out.view(input.shape)
- broadcast_dims = [0] + list(range(2, input.ndim))
- unsqueeze_bias = None
- if bias is not None:
- unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
- unsqueeze_weight = None
- if weight is not None:
- unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
- if unsqueeze_weight is not None:
- out = out * unsqueeze_weight
- if unsqueeze_bias is not None:
- out = out + unsqueeze_bias
- out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
- mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
- rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
- # remove broadcast dimensions from mean and rstd
- mean = torch.squeeze(mean, reduction_dims)
- rstd = torch.squeeze(rstd, reduction_dims)
- return (out, mean, rstd)
- @register_decomposition(aten.native_layer_norm)
- @out_wrapper("out0", "out1", "out2")
- def native_layer_norm(
- input: Tensor,
- normalized_shape: ShapeType,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- normalized_ndim = len(normalized_shape)
- torch._check(
- normalized_ndim >= 1,
- lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
- + "containing at least one element, but got normalized_shape = "
- + str(normalized_shape),
- )
- # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
- # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
- # therefore we use tuple(normalized_shape)
- torch._check(
- weight is None or weight.shape == tuple(normalized_shape),
- lambda: "Expected weight to be of same shape as normalized_shape, but got "
- + "weight of shape "
- + str(weight.shape) # type: ignore[union-attr]
- + " and normalized_shape = "
- + str(normalized_shape),
- )
- torch._check(
- bias is None or bias.shape == tuple(normalized_shape),
- lambda: "Expected bias to be of same shape as normalized_shape, but got "
- + "bias of shape "
- + str(bias.shape) # type: ignore[union-attr]
- + " and normalized_shape = "
- + str(normalized_shape),
- )
- torch._check(
- input.ndim >= normalized_ndim
- and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
- lambda: "Given normalized_shape="
- + str(normalized_shape)
- + ", expected input with shape "
- + str(normalized_shape)
- + ", but got input of size "
- + str(input.shape),
- )
- input = input.contiguous()
- if weight is not None:
- weight = weight.contiguous()
- if bias is not None:
- bias = bias.contiguous()
- axis = input.ndim - normalized_ndim
- reduction_dims = list(range(axis, input.ndim))
- out, mean, rstd = _normalize(input, reduction_dims, eps)
- if weight is None and bias is not None:
- out = out + bias
- elif weight is not None and bias is None:
- out = out * weight
- elif weight is not None and bias is not None:
- out = out * weight + bias
- out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
- if input.device.type == "cpu":
- mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
- rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
- return (out, mean, rstd)
- # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
- # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
- @register_decomposition(aten.permute)
- def permute(a: TensorLikeType, *dims) -> TensorLikeType:
- _permutation = utils.canonicalize_dims(
- a.ndim, utils.extract_dims_from_varargs(dims)
- )
- return prims.transpose(a, _permutation)
- @register_decomposition(aten.renorm)
- @out_wrapper()
- def renorm(
- input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType
- ) -> TensorLikeType:
- torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued")
- torch._check(p > 0, lambda: "renorm: non-positive norm not supported")
- torch._check(
- not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued"
- )
- torch._check(
- maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}"
- )
- ndim = input.ndim
- torch._check(
- ndim > 1,
- lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions",
- )
- dim = utils.canonicalize_dim(ndim, dim)
- reduce_dims = list(range(ndim))
- del reduce_dims[dim]
- # For half and bfloat16, calculate norm in float precision then cast
- # normalization factor to half
- acc_type = utils.get_computation_dtype(input.dtype)
- if acc_type != input.dtype:
- norm = torch.linalg.vector_norm(
- input, p, reduce_dims, keepdim=True, dtype=acc_type
- )
- else:
- norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True)
- eps = 1e-7
- norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
- if acc_type != input.dtype:
- norm_factor = prims.convert_element_type(norm_factor, input.dtype)
- return (input * norm_factor).contiguous()
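- # Editor's note: an illustrative check (not part of this module) of the
- # clamping behavior above; sub-tensors whose p-norm exceeds maxnorm are
- # rescaled to (just under) maxnorm, the rest pass through unchanged. The
- # helper name is hypothetical.
- def _example_renorm():
-     import torch
-     a = torch.tensor([[3.0, 4.0], [0.3, 0.4]])  # row norms are 5.0 and 0.5
-     out = torch.renorm(a, p=2, dim=0, maxnorm=1.0)
-     norms = out.norm(dim=1)
-     assert norms[0] <= 1.0 and torch.isclose(norms[1], torch.tensor(0.5))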
- # CompositeImplicitAutograd - don't register decomp
- @aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
- def stft(
- input: Tensor,
- n_fft: int,
- hop_length: Optional[int] = None,
- win_length: Optional[int] = None,
- window: Optional[Tensor] = None,
- center: bool = True,
- pad_mode: str = "reflect",
- normalized: bool = False,
- onesided: Optional[bool] = None,
- return_complex: Optional[bool] = None,
- ) -> Tensor:
- torch._check(
- window is None or window.device == input.device,
- lambda: (
- f"stft input and window must be on the same device but got self on {input.device}"
- + f" and window on {window.device}" # type: ignore[union-attr]
- ),
- )
- hop_length_ = hop_length if hop_length is not None else n_fft // 4
- win_length_ = win_length if win_length is not None else n_fft
- if return_complex is None:
- return_complex_ = input.is_complex() or (
- window is not None and utils.is_complex_dtype(window.dtype)
- )
- torch._check(
- return_complex_,
- lambda: (
- "stft requires the return_complex parameter be given for real inputs, "
- + "and will further require that return_complex=True in a future PyTorch release."
- ),
- )
- else:
- return_complex_ = return_complex
- torch._check(
- utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
- lambda: "stft expected a tensor of floating point or complex values",
- )
- torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")
- original_ndim = input.ndim
- if original_ndim == 1:
- input = input.unsqueeze(0)
- if center:
- extra_dims = 3 - input.ndim
- pad_amount = n_fft // 2
- extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
- input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
- input = input.view(input.size()[extra_dims:])
- batch = input.size(0)
- length = input.size(1)
- torch._check(
- 0 < n_fft <= length,
- lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
- )
- torch._check(
- hop_length_ > 0,
- lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
- )
- torch._check(
- 0 < win_length_ <= n_fft,
- lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
- )
- torch._check(
- window is None or window.shape == (win_length_,),
- lambda: (
- f"expected a 1D window tensor of size equal to win_length={win_length_}, "
- + f"but got window with size {window.shape}" # type: ignore[union-attr]
- ),
- )
- if win_length_ < n_fft:
- if window is None:
- window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
- left = (n_fft - win_length_) // 2
- window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
- input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
- if window is not None:
- input = input * window
- complex_fft = utils.is_complex_dtype(input.dtype)
- onesided = onesided if onesided is not None else not complex_fft
- norm = "ortho" if normalized else None
- if onesided:
- torch._check(
- not complex_fft,
- lambda: "Cannot have onesided output if window or input is complex",
- )
- out = torch.fft.rfft(input, dim=-1, norm=norm)
- else:
- out = torch.fft.fft(input, dim=-1, norm=norm)
- out.transpose_(1, 2)
- if original_ndim == 1:
- out = out.squeeze_(0)
- return out if return_complex_ else torch.view_as_real(out)
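- # Editor's note: a hedged shape sanity check (not part of this module) for the
- # framing logic above; with center=False, unfold produces
- # 1 + (length - n_fft) // hop_length frames, and a real input keeps
- # n_fft // 2 + 1 onesided frequency bins. The helper name is hypothetical.
- def _example_stft_shape():
-     import torch
-     x = torch.randn(1, 1000)
-     out = torch.stft(x, n_fft=64, hop_length=16, center=False, return_complex=True)
-     assert out.shape == (1, 64 // 2 + 1, 1 + (1000 - 64) // 16)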
- # CompositeImplicitAutograd - don't register decomp
- @aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
- def istft(
- input: Tensor,
- n_fft: int,
- hop_length: Optional[int] = None,
- win_length: Optional[int] = None,
- window: Optional[Tensor] = None,
- center: bool = True,
- normalized: bool = False,
- onesided: Optional[bool] = None,
- length: Optional[int] = None,
- return_complex=False,
- ) -> Tensor:
- torch._check(
- window is None or window.device == input.device,
- lambda: (
- f"istft input and window must be on the same device but got self on {input.device}"
- + f" and window on {window.device}" # type: ignore[union-attr]
- ),
- )
- hop_length_ = hop_length if hop_length is not None else n_fft // 4
- win_length_ = win_length if win_length is not None else n_fft
- torch._check(
- utils.is_complex_dtype(input.dtype),
- lambda: (
- "istft requires a complex-valued input tensor matching the "
- + "output from stft with return_complex=True"
- ),
- )
- n_frames = input.size(-1)
- fft_size = input.size(-2)
- expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
- torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
- torch._check(
- 2 <= input.ndim <= 3,
- lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
- )
- onesided_ = onesided if onesided is not None else fft_size != n_fft
- if onesided_:
- torch._check(
- n_fft // 2 + 1 == fft_size,
- lambda: (
- "istft expected the frequency dimension (3rd to the last) of the input tensor "
- + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
- ),
- )
- else:
- torch._check(
- n_fft == fft_size,
- lambda: (
- "istft expected the frequency dimension (3rd to the last) of the input tensor "
- + "to match n_fft when onesided=False, but got {fft_size}",
- ),
- )
- torch._check(
- 0 < hop_length_ <= win_length_,
- lambda: "istft expected 0 < hop_length <= win_length",
- )
- torch._check(
- 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
- )
- torch._check(
- window is None or window.shape == (win_length_,),
- lambda: "Invalid window shape. window has to be 1D and length of `win_length`",
- )
- if window is None:
- real_dtype = utils.corresponding_real_dtype(input.dtype)
- window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
- else:
- window_ = window
- if win_length_ != n_fft:
- left = (n_fft - win_length_) // 2
- window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)
- original_ndim = input.ndim
- if input.ndim == 2:
- input = input.unsqueeze(0)
- input = input.transpose(1, 2)
- norm = "ortho" if normalized else None
- if return_complex:
- torch._check(
- not onesided_,
- lambda: "cannot have onesided output if window or input is complex",
- )
- input = torch.fft.ifft(input, dim=-1, norm=norm)
- else:
- torch._check(
- window is None or not utils.is_complex_dtype(window.dtype),
- lambda: "Complex windows are incompatible with return_complex=False",
- )
- if not onesided_:
- input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
- input = torch.fft.irfft(input, dim=-1, norm=norm)
- assert input.size(2) == n_fft
- y_tmp = input * window_.view([1, 1, n_fft])
- y = aten.unfold_backward(
- y_tmp,
- input_sizes=(y_tmp.size(0), expected_output_signal_len),
- dim=1,
- size=n_fft,
- step=hop_length_,
- )
- window_envelop = aten.unfold_backward(
- window_.pow(2).expand((1, n_frames, n_fft)),
- input_sizes=(y_tmp.size(0), expected_output_signal_len),
- dim=1,
- size=n_fft,
- step=hop_length_,
- )
- assert expected_output_signal_len == y.size(1)
- assert expected_output_signal_len == window_envelop.size(1)
- start = n_fft // 2 if center else 0
- if length is not None:
- end = start + length
- elif center:
- end = expected_output_signal_len - n_fft // 2
- else:
- end = expected_output_signal_len
- length = max(0, end - start)
- y = y.narrow(dim=1, start=start, length=length)
- window_envelop = window_envelop.narrow(dim=1, start=start, length=length)
- window_envelop_lowest = window_envelop.abs().min().lt(1e-11)
- torch._check(
- not window_envelop_lowest.item(),
- lambda: "window overlap add min less than 1e-11",
- )
- y = y / window_envelop
- if original_ndim == 2:
- y = y.squeeze(0)
- if end > expected_output_signal_len:
- warnings.warn(
- "The length of signal is shorter than the length parameter. Result is being "
- + "padded with zeros in the tail. Please check your center and hop_length settings"
- )
- y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
- return y
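- # Editor's note: an illustrative round trip (not part of this module); with a
- # window satisfying the overlap-add constraint checked above, istft inverts
- # stft up to floating-point error. The helper name is hypothetical.
- def _example_stft_istft_roundtrip():
-     import torch
-     x = torch.randn(2048)
-     window = torch.hann_window(400)
-     spec = torch.stft(x, n_fft=400, hop_length=100, window=window, return_complex=True)
-     y = torch.istft(spec, n_fft=400, hop_length=100, window=window, length=x.numel())
-     assert torch.allclose(x, y, atol=1e-5)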
- # Get the new shape and stride after applying unfold to an input tensor
- def _get_unfold_shape_stride(
- a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
- ):
- a_ndim = len(a_shape)
- dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
- max_size = 1 if a_ndim == 0 else a_shape[dim]
- last_stride = 1 if a_ndim == 0 else a_stride[dim]
- torch._check(
- size <= max_size,
- lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
- )
- torch._check(
- step > 0,
- lambda: f"Step is {step} but must be > 0",
- )
- shape = list(a_shape)
- strides = list(a_stride)
- shape.append(size)
- strides.append(last_stride)
- if dim < a_ndim:
- shape[dim] = (shape[dim] - size) // step + 1
- strides[dim] *= step
- return shape, strides
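- # Editor's note: a worked example (not part of this module) of the shape/stride
- # arithmetic above: a contiguous length-10 tensor unfolded with size=3, step=2
- # gives (10 - 3) // 2 + 1 = 4 windows of 3, i.e. shape (4, 3) and strides (2, 1).
- # The helper name is hypothetical.
- def _example_unfold_shape_stride():
-     import torch
-     windows = torch.arange(10).unfold(dimension=0, size=3, step=2)
-     assert windows.shape == (4, 3) and windows.stride() == (2, 1)
-     assert torch.equal(windows[1], torch.tensor([2, 3, 4]))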
- @register_decomposition(aten.repeat)
- @out_wrapper()
- def repeat(a: Tensor, *repeat_shape) -> Tensor:
- repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
- torch._check(
- len(repeat_shape) >= len(a.shape),
- lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
- )
- if len(repeat_shape) == 0:
- return torch.clone(a)
- num_new_dimensions = len(repeat_shape) - a.ndim
- padded_shape = [1] * num_new_dimensions
- for dim_size in a.shape:
- padded_shape.append(dim_size)
- target_shape = tuple(
- padded_size * repeat_size
- for padded_size, repeat_size in zip(padded_shape, repeat_shape)
- )
- # return an empty tensor if one of the repeat_shape dimensions is zero
- if 0 in repeat_shape:
- return torch.empty(
- target_shape,
- dtype=a.dtype,
- device=a.device,
- requires_grad=a.requires_grad,
- memory_format=utils.suggest_memory_format(a),
- )
- urtensor_shape = target_shape
- urtensor_stride = utils.make_contiguous_strides_for(target_shape)
- for dim, dim_size in enumerate(padded_shape):
- # repeat each dimension by using unfold_copy operation
- urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
- urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
- )
- # derive permute order by sorting urtensor strides
- enumerated_stride = list(enumerate(urtensor_stride))
- enumerated_stride.sort(key=operator.itemgetter(1), reverse=True)
- permute_order, sorted_stride = zip(*enumerated_stride)
- # add new and expand dimensions according to urtensor
- repeat_xtensor = a.expand(urtensor_shape)
- # clone tensor to concretize expanded dimensions
- cloned_result = torch.clone(repeat_xtensor)
- # transpose axis so strides are in sorted order
- permuted_result = cloned_result.permute(permute_order)
- # reshape to get contiguous tensor with correct target shape
- return permuted_result.reshape(target_shape)
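- # Editor's note: a hedged sanity sketch (not part of this module) of the
- # semantics implemented above; repeat tiles the (left-padded) input along each
- # dimension of repeat_shape. The helper name is hypothetical.
- def _example_repeat():
-     import torch
-     a = torch.tensor([[1, 2], [3, 4]])
-     r = a.repeat(2, 3)  # target shape (2 * 2, 3 * 2) == (4, 6)
-     assert torch.equal(r, torch.cat([torch.cat([a] * 3, dim=1)] * 2, dim=0))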
- def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
- # Creates a valid shape
- shape = utils.extract_shape_from_varargs(shape, validate=False)
- # Reshape may be given a shape with a -1 length
- # This indicates that the dimension's length should be inferred
- shape = utils.infer_size(shape, a.numel())
- # Special-cases tensors with no elements
- if guard_size_oblivious(a.numel() == 0):
- return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
- # Special-cases reshaping zero dim tensors
- if a.ndim == 0:
- _a = a
- for length in shape:
- assert length == 1
- _a = unsqueeze(_a, -1)
- if _a is a:
- return prims.view_of(a)
- else:
- return _a
- # Special-cases reshaping to zero dim tensors
- if len(shape) == 0:
- _a = a
- for length in a.shape:
- assert length == 1
- _a = squeeze(_a, -1)
- if _a is a:
- return prims.view_of(a)
- else:
- return _a
- if a.is_contiguous():
- # Special-cases for nd_to_1d
- if len(shape) == 1 and a.ndim > 1:
- return torch.as_strided(a, [a.numel()], [1])
- # Special-cases for 1d_to_2d
- if len(shape) == 2 and a.ndim == 1:
- dim0 = shape[0]
- dim1 = shape[1]
- return torch.as_strided(a, [dim0, dim1], [dim1, 1])
- # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
- # NOTE [Reshape Algorithm]
- # This algorithm works by attempting to greedily construct the desired dimensions in
- # the output shape, left to right. It does this by, conceptually, accumulating
- # dimensions of the original tensor, also left to right, until the dimension
- # can be constructed using prims.split_dim.
- # The algorithm also has special handling for tail squeezes/unsqueezes, like
- # a reshape from (5, 5) to (5, 5, 1) or vice versa.
- #
- # This algorithm does not flatten the original tensor and then split dims as appropriate
- # because that would create copies more often than this algorithm. flatten is the only
- # operation below which can create a view or a copy, and while it prefers creating
- # views it may sometimes create a copy if the tensor's strides do not permit a view.
- # As a result, this algorithm tries to minimize flattening.
- #
- # Note that a better version of this algorithm may exist. Regions which could be
- # flattened without creating a copy can be identified in advance, and that might
- # allow fewer flatten calls or faster short-circuiting to make a copy.
- idx = 0
- a_ = a
- for length in shape:
- # Handles tail unsqueezes
- if idx >= a_.ndim:
- assert length == 1
- last_dim = a_.ndim - 1
- # NOTE: using split_dim instead of unsqueeze may seem silly here,
- # but it's necessary to get the strides correct
- a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
- idx = idx + 1
- continue
- # Skips dimensions that are already the correct length
- if guard_size_oblivious(length == a_.shape[idx]):
- idx = idx + 1
- continue
- # Gathers enough original dimensions such that this new dimension can be created
- # Note that this accumulation will terminate because we've verified a and the shape
- # specify the same number of elements above
- accum = a_.shape[idx]
- end = idx
- while guard_size_oblivious(accum % length != 0):
- end = end + 1
- accum = accum * a_.shape[end]
- if end != idx:
- # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
- # This flattening is why reshape sometimes creates a copy -- because flattening
- # may return a view of a copy
- # Checks if collapse can be a view and short-circuits to copying reshape if it can't
- new_shape, new_strides = prims._collapse_view_helper(a_, idx, end)
- if new_shape is None:
- if allow_copy:
- return prims.reshape(a, shape)
- msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
- raise ValueError(msg)
- a_ = flatten(a_, idx, end)
- # Splits the (possibly flattened) dimension to create the desired dim length
- if guard_size_oblivious(accum != length):
- a_ = prims.split_dim(a_, idx, length)
- idx = idx + 1
- # Squeezes tail
- while idx < a_.ndim:
- torch._check(
- a_.shape[idx] == 1,
- lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
- )
- a_ = squeeze(a_, idx)
- if a_ is a:
- return prims.view_of(a)
- else:
- return a_
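- # Editor's note: an illustrative contrast (not part of this module) of the
- # allow_copy distinction above: a transposed tensor cannot be flattened as a
- # view, so view() raises where reshape() falls back to a copy. The helper name
- # is hypothetical.
- def _example_reshape_vs_view():
-     import torch
-     t = torch.arange(6).reshape(2, 3).t()  # shape (3, 2), strides (1, 3)
-     try:
-         t.view(6)
-         raise AssertionError("expected view() to fail on a non-viewable input")
-     except RuntimeError:
-         pass
-     assert t.reshape(6).is_contiguous()  # reshape materializes a copy instead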
- # CompositeImplicitAutograd - don't register decomp
- # NOTE: shape is a vararg because Tensor.reshape can be called as
- # Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)). The function call
- # torch.reshape doesn't support unpacked shapes
- def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
- return _reshape_view_helper(a, *shape, allow_copy=True)
- # CompositeImplicitAutograd - don't register decomp
- def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
- return self.reshape(other.size())
- @register_decomposition(aten.roll)
- @out_wrapper()
- def roll(
- a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
- ) -> TensorLikeType:
- """Reference implementation of :func:`torch.roll`."""
- dims = utils.canonicalize_dims(a.ndim, dims)
- # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
- if not isinstance(shifts, Iterable):
- shifts = (shifts,)
- if not isinstance(dims, Iterable):
- dims = (dims,)
- # Avoid modulo by zero
- if a.numel() == 0:
- # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
- return a.clone()
- if a.dim() == 0 and len(dims) > 0:
- raise IndexError(
- f"Dimension specified as {dims[0]} but tensor has no dimensions"
- )
- len_shifts = len(shifts)
- len_dims = len(dims)
- if len_shifts != 1 or len_dims != 1:
- if len_shifts == 0:
- raise RuntimeError("`shifts` required")
- # Takes care of the case when dims is not specified (default)
- # By default, the tensor is flattened before shifting, after which the original shape is restored
- if len_dims == 0 and len_shifts == 1:
- return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
- if len_shifts != len_dims:
- raise RuntimeError(
- f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
- )
- assert len_dims > 1
- tail_shifts = shifts[1:]
- tail_dims = dims[1:]
- first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
- return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
- # This path is taken when only one dimension is rolled
- # For example to get `first_dim_rolled` above
- dim = dims[0]
- size = a.shape[dim]
- start = (size - shifts[0]) % size
- idx = torch.arange(size, device=a.device)
- return a.index_select(dim, torch.fmod(start + idx, size))
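- # Editor's note: an illustrative check (not part of this module) of the
- # single-dimension path above: rolling by `shift` is an index_select with
- # indices (start + i) % size, where start = (size - shift) % size. The helper
- # name is hypothetical.
- def _example_roll_single_dim():
-     import torch
-     a, size, shift = torch.arange(5), 5, 2
-     idx = ((size - shift) + torch.arange(size)) % size
-     assert torch.equal(a.index_select(0, idx), torch.roll(a, shift))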
- @register_decomposition(aten.rot90)
- @out_wrapper()
- def rot90(
- a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
- ) -> TensorLikeType:
- """Reference implementation of :func:`torch.rot90`."""
- if len(dims) != 2:
- raise RuntimeError(
- f"expected total rotation dims == 2, but got dims = {len(dims)}"
- )
- if a.ndim < 2:
- raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
- # Do this after the initial checks to be compatible with the behavior in
- # core.
- dims = utils.canonicalize_dims(a.ndim, dims)
- if dims[0] == dims[1]:
- raise RuntimeError(
- f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
- )
- k = k % 4 # Rotation direction is from the second towards the first axis for k < 0
- if k == 1:
- return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
- elif k == 2:
- return torch.flip(a, dims)
- elif k == 3:
- return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
- else:
- return clone(a, memory_format=torch.contiguous_format)
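- # Editor's note: an illustrative check (not part of this module) of the k == 1
- # branch above: a 90-degree rotation is a flip of dims[1] followed by a
- # transpose of the two rotation dims. The helper name is hypothetical.
- def _example_rot90():
-     import torch
-     a = torch.arange(6).reshape(2, 3)
-     manual = torch.transpose(torch.flip(a, (1,)), 0, 1)
-     assert torch.equal(manual, torch.rot90(a, 1, (0, 1)))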
- def _check_stack_inputs(tensors: TensorSequenceType) -> None:
- entry_shape = tensors[0].shape
- for i in range(1, len(tensors)):
- assert tensors[i].shape == entry_shape, (
- f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
- f"and {tensors[i].shape} at entry {i}"
- )
- @register_decomposition(aten.stack)
- @out_wrapper()
- def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
- assert len(tensors) > 0, "stack expects a non-empty TensorList"
- wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
- # Refs need sparse support to check other condition
- if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse:
- _check_stack_inputs(tensors)
- result_sizes = list(tensors[0].shape)
- result_sizes.insert(wrapped_dim, len(tensors))
- out = torch.cat(tensors, wrapped_dim)
- return out.view(result_sizes)
- # If dim == tensors[0].ndim, view cannot efficiently handle it
- return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
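- # Editor's note: an illustrative equivalence (not part of this module) for the
- # fallback path above: stack is a cat over inputs unsqueezed at the stack dim.
- # The helper name is hypothetical.
- def _example_stack():
-     import torch
-     ts = [torch.randn(2, 3) for _ in range(4)]
-     expected = torch.cat([t.unsqueeze(1) for t in ts], dim=1)
-     assert torch.equal(torch.stack(ts, dim=1), expected)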
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def softmax(
- a: TensorLikeType,
- dim: int,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(result_dtype)
- a_ = _maybe_convert_to_dtype(a, computation_dtype)
- if a.numel() == 0:
- a_exp = exp(a_)
- else:
- a_max = amax(a_, dim, keepdim=True)
- a_exp = exp(a_ - a_max)
- return _maybe_convert_to_dtype(
- true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
- ) # type: ignore[return-value]
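- # Editor's note: a sketch (not part of this module) of the max-subtraction
- # trick above; shifting by the rowwise max before exp avoids overflow and
- # leaves the result unchanged. The helper name is hypothetical.
- def _example_stable_softmax():
-     import torch
-     a = torch.tensor([[1000.0, 1001.0, 1002.0]])  # naive exp() overflows to inf
-     e = torch.exp(a - a.amax(dim=-1, keepdim=True))
-     assert torch.allclose(e / e.sum(dim=-1, keepdim=True), torch.softmax(a, dim=-1))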
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def hstack(tensors: TensorSequenceType) -> TensorLikeType:
- torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
- aligned_tensors = atleast_1d(*tensors)
- if aligned_tensors[0].ndim == 1:
- return cat(aligned_tensors, 0)
- return cat(aligned_tensors, 1)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def vstack(tensors: TensorSequenceType) -> TensorLikeType:
- torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
- aligned_tensors = atleast_2d(*tensors)
- return cat(aligned_tensors, 0)
- # CompositeImplicitAutograd - don't register decomp
- def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
- dim = utils.canonicalize_dim(a.ndim, dim)
- torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
- return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
- @register_decomposition(aten.unbind)
- def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- dim = utils.canonicalize_dim(t.ndim, dim)
- torch._check_index(
- len(t.shape) > 0,
- lambda: "Dimension specified as 0 but tensor has no dimensions",
- )
- if guard_size_oblivious(t.shape[dim] == 0):
- return tuple()
- else:
- return tuple(
- torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
- )
- @out_wrapper()
- def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- return x.clone(memory_format=torch.contiguous_format).index_copy_(
- dim, index, tensor
- )
- def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- torch._check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- # Treat scalars as elements of \R^1
- y = x.unsqueeze(0) if x.ndim == 0 else x
- idx = (slice(None),) * dim + (index,)
- y[idx] = tensor
- return x
- @register_decomposition(aten.index_fill)
- @out_wrapper()
- def index_fill(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
- ):
- return _index_fill(x, dim, index, value, inplace=False)
- @register_decomposition(aten.index_fill_)
- def index_fill_(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
- ):
- return _index_fill(x, dim, index, value, inplace=True)
- def _index_fill(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- value: Union[NumberType, TensorLike],
- *,
- inplace: bool,
- ):
- torch._check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if isinstance(value, TensorLike):
- torch._check(
- value.ndim == 0,
- lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
- f"Got a tensor with {value.ndim} dimensions.",
- ) # type: ignore[arg-type]
- else:
- value = torch.scalar_tensor(
- value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type]
- )
- # index_copy has some unnecessary preconditions when x is a scalar; unsqueezing works around them
- zero_dim = x.ndim == 0
- y = x.unsqueeze(0) if zero_dim else x
- # index_copy does not broadcast on value so we have to do it manually
- shape = list(y.shape)
- shape[dim] = index.numel()
- value = value.expand(shape)
- index_copy = Tensor.index_copy_ if inplace else torch.index_copy
- out = index_copy(y, dim, index, value) # type: ignore[operator]
- if inplace:
- return x
- else:
- if zero_dim:
- # The clone is necessary so that it returns a fresh tensor rather than a view
- out = out.squeeze(0).clone()
- # index_fill preserves the strides. index_copy always returns contiguous tensors
- if out.stride() != x.stride():
- new_out = torch.empty_like(x)
- new_out.copy_(out)
- out = new_out
- return out
- @out_wrapper()
- def index_add(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- tensor: TensorLike,
- *,
- alpha: NumberType = 1,
- ):
- # index_add always returns a new contiguous tensor
- return x.clone(memory_format=torch.contiguous_format).index_add_(
- dim, index, tensor, alpha=alpha # type: ignore[arg-type]
- )
- @register_decomposition(aten.index_select)
- @out_wrapper()
- def index_select(x: TensorLike, dim: int, index: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- torch._check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if index.ndim == 0:
- index = index.unsqueeze(0)
- if x.ndim == 0:
- # Treat scalars as elements of \R^1
- # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
- return torch.empty_like(x).index_copy(0, index, x.expand_as(index))
- idx = (slice(None),) * dim + (index,)
- return x[idx]
- @register_decomposition(aten.squeeze.dims)
- def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- if dim is None:
- dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
- return prims.squeeze(a, dims) if dims else prims.view_of(a)
- ndim = a.ndim
- dim = utils.canonicalize_dims(ndim, dim)
- dims = (dim,) if isinstance(dim, Dim) else dim
- # Short-circuits if the tensor has no dimensions
- if ndim == 0:
- assert len(dims) == 0 or dims == (0,)
- return prims.view_of(a)
- # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
- dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
- if len(dims) == 0:
- return prims.view_of(a)
- if len(dims) == 1:
- return prims.squeeze(a, dims)
- dims_list = list(dims)
- dims_list = sorted(dims_list, reverse=True)
- for i in dims_list:
- a = squeeze(a, i)
- return a
- # Note: does not work with TensorMetas because of data-dependent control-flow
- # CompositeImplicitAutograd - don't register decomp
- def tensor_split(
- a: TensorLikeType,
- indices_or_sections: Union[Tensor, DimsType],
- dim: int = 0,
- ) -> Tuple[TensorLikeType, ...]:
- _dim = utils.canonicalize_dim(a.ndim, dim)
- if a.ndim == 0:
- msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
- raise ValueError(msg)
- # If indices_or_sections is a tensor, it must be a CPU Long tensor
- if isinstance(indices_or_sections, TensorLike):
- if not indices_or_sections.device.type == "cpu":
- msg = (
- f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
- f"but received one on {indices_or_sections.device}"
- )
- raise ValueError(msg)
- if indices_or_sections.dtype != torch.long:
- msg = "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
- f" but received one with dtype {indices_or_sections.dtype}"
- raise ValueError(msg)
- # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
- if isinstance(indices_or_sections, IntLike) or (
- isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
- ):
- sections: int = (
- indices_or_sections # type: ignore[assignment]
- if isinstance(indices_or_sections, Number)
- else indices_or_sections.item()
- )
- if sections <= 0:
- msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
- raise ValueError(msg)
- splits = []
- dim_size = a.shape[_dim]
- min_split_size = math.floor(dim_size / sections)
- num_splits_one_extra = dim_size % sections
- start_idx = 0
- for split_idx in range(sections):
- split_size = (
- min_split_size + 1
- if (split_idx < num_splits_one_extra)
- else min_split_size
- )
- s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
- splits.append(s)
- start_idx = start_idx + split_size
- return tuple(splits)
- # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
- else:
- indices = indices_or_sections
- if isinstance(indices_or_sections, TensorLike):
- if indices_or_sections.ndim != 1:
- msg = "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
- f"but received a tensor with {indices_or_sections.ndim} dimensions"
- raise ValueError(msg)
- indices = indices_or_sections.tolist()
- splits = []
- start_idx = 0
- for x in indices:
- splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
- start_idx = x
- splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
- return tuple(splits)
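- # Editor's note: a worked example (not part of this module) of the equal-ish
- # split in Case 0 above: 10 elements over 3 sections gives 10 % 3 == 1 split
- # of size floor(10 / 3) + 1 and two of size floor(10 / 3). The helper name is
- # hypothetical.
- def _example_tensor_split():
-     import torch
-     parts = torch.tensor_split(torch.arange(10), 3)
-     assert [p.numel() for p in parts] == [4, 3, 3]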
- # CompositeImplicitAutograd - don't register decomp
- def hsplit(
- a: TensorLikeType, indices_or_sections: DimsType
- ) -> Tuple[TensorLikeType, ...]:
- torch._check(
- a.ndim >= 1,
- lambda: (
- "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
- + str(a.ndim)
- + " dimensions!"
- ),
- )
- dim = 0 if a.ndim == 1 else 1
- if isinstance(indices_or_sections, IntLike):
- split_size = indices_or_sections
- torch._check(
- (split_size != 0 and a.shape[dim] % split_size == 0),
- lambda: (
- "torch.hsplit attempted to split along dimension "
- + str(dim)
- + ", but the size of the dimension "
- + str(a.shape[dim])
- + " is not divisible by the split_size "
- + str(split_size)
- + "!"
- ),
- )
- return tensor_split(a, split_size, dim)
- torch._check_type(
- isinstance(indices_or_sections, (list, tuple)),
- lambda: (
- "hsplit(): received an invalid combination of arguments. "
- "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
- f"but got type {type(indices_or_sections)}"
- ),
- )
- split_sizes = indices_or_sections
- return tensor_split(a, split_sizes, dim)
- # CompositeImplicitAutograd - don't register decomp
- def vsplit(
- a: TensorLikeType, indices_or_sections: DimsType
- ) -> Tuple[TensorLikeType, ...]:
- torch._check(
- a.ndim >= 2,
- lambda: (
- "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
- + str(a.ndim)
- + " dimensions!"
- ),
- )
- if isinstance(indices_or_sections, IntLike):
- split_size = indices_or_sections
- torch._check(
- (split_size != 0 and a.shape[0] % split_size == 0),
- lambda: (
- "torch.vsplit attempted to split along dimension 0, but the size of the dimension "
- f"{a.shape[0]} is not divisible by the split_size {split_size}!"
- ),
- )
- return tensor_split(a, split_size, 0)
- torch._check_type(
- isinstance(indices_or_sections, (list, tuple)),
- lambda: (
- "vsplit(): received an invalid combination of arguments. "
- "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
- f"but got type {type(indices_or_sections)}"
- ),
- )
- split_sizes = indices_or_sections
- return tensor_split(a, split_sizes, 0)
- @register_decomposition(aten.diag.out)
- @out_wrapper()
- def diag(
- self: TensorLikeType,
- offset: int = 0,
- ) -> TensorLikeType:
- ndim = self.dim()
- torch._check(
- ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
- )
- if ndim == 1:
- return torch.diag_embed(self, offset)
- else:
- return torch.diagonal_copy(self, offset)
- @register_decomposition(aten.diagonal_scatter)
- @out_wrapper()
- def diagonal_scatter(
- input: TensorLikeType,
- src: TensorLikeType,
- offset: int = 0,
- dim1: int = 0,
- dim2: int = 1,
- ) -> TensorLikeType:
- out = utils.clone_preserve_strides(input)
- diag = out.diagonal(offset, dim1, dim2)
- torch._check(
- diag.shape == src.shape,
- lambda: "expected src to have a size equal to the diagonal of the input."
- f"Got {src.shape} for a diagonal of shape {diag.shape}",
- )
- copy_to(diag, src)
- return out
- @register_decomposition(aten.diagonal)
- def diagonal(
- self: TensorLikeType,
- offset: int = 0,
- dim1: int = 0,
- dim2: int = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.diagonal
- """
- num_dims = self.dim()
- dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
- dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
- torch._check(
- dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
- )
- storage_offset = self.storage_offset()
- if offset >= 0:
- diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
- else:
- diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)
- if diag_size > 0:
- if offset >= 0:
- storage_offset += offset * self.stride()[dim2]
- else:
- storage_offset -= offset * self.stride()[dim1]
- sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
- sizes.append(diag_size)
- strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
- strides.append(self.stride()[dim1] + self.stride()[dim2])
- result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)
- return result
- diagonal_copy = _make_copy_from_view(diagonal)
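- # Editor's note: an illustrative sketch (not part of this module) of the
- # as_strided construction above; an element stride of stride(dim1) + stride(dim2)
- # walks the diagonal. The helper name is hypothetical.
- def _example_diagonal_as_strided():
-     import torch
-     a = torch.arange(9).reshape(3, 3)  # strides (3, 1)
-     diag = a.as_strided(size=(3,), stride=(3 + 1,))
-     assert torch.equal(diag, a.diagonal())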
- @register_decomposition(aten.diag_embed)
- @out_wrapper()
- def diag_embed(
- t: TensorLikeType,
- offset: int = 0,
- dim1: int = -2,
- dim2: int = -1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.diag_embed
- """
- # convert from negative dims
- rank = t.ndim + 1
- dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
- dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
- # as per the docs, exchanging dims is equivalent to changing the sign of
- # offset
- if dim1 > dim2:
- dim1, dim2 = dim2, dim1
- offset = -offset
- torch._check(
- dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
- )
- # as per the docs, the size of last dim is placed at dim1 and dim2
- last_dim = t.size(-1)
- if offset != 0:
- # add padding to match the new size
- t_shape = list(t.shape)
- t_shape[-1] = builtins.abs(offset)
- z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
- pair = (z, t) if offset > 0 else (t, z)
- t = torch.cat(pair, dim=-1)
- # make sure the diagonal always has the same size
- last_dim += builtins.abs(offset)
- # preserve original data, but place 1 at dim1 and move last dim to dim2
- t = t.unsqueeze(dim1).movedim(-1, dim2)
- # generate ranges shifting indices based on offset
- a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
- b_range = torch.arange(
- offset, last_dim + offset, device=t.device, dtype=torch.int64
- )
- # broadcast
- cond = a_range == b_range.unsqueeze(-1)
- cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
- cond = cond.reshape(cond_shape)
- # aten.diag_embed always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(cond, t).contiguous()
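- # Editor's note: an illustrative check (not part of this module) of the
- # mask-based construction above: the last input dim lands on the selected
- # diagonal and everything else is zero. The helper name is hypothetical.
- def _example_diag_embed():
-     import torch
-     v = torch.tensor([1.0, 2.0, 3.0])
-     m = torch.diag_embed(v)  # (3, 3), with v on the main diagonal
-     assert torch.equal(m.diagonal(), v) and m.sum() == v.sum()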
- @register_decomposition(aten.block_diag)
- @out_wrapper()
- def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
- """
- Reference implementation of torch.block_diag
- """
- tensors_2d = [
- tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors
- ]
- ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d)
- device = tensors_2d[0].device
- result = []
- col_start = 0
- for i, tensor in enumerate(tensors_2d):
- torch._check(
- tensor.dim() == 2,
- lambda: "Input tensors must have 2 or fewer dimensions. "
- f"Input {i} has {tensor.dim()} dimensions",
- )
- torch._check(
- tensor.device == device,
- lambda: "Input tensors must all be on the same device. "
- f"Input 0 is on device {device} and input {i} is on device {tensor.device}.",
- )
- row, col = tensor.shape
- left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype)
- right = torch.zeros(
- (row, ncols - col_start - col), device=device, dtype=tensor.dtype
- )
- result += [torch.cat((left, tensor, right), dim=1)]
- col_start += col
- return torch.cat(result, dim=0)
- def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
- """
- This is used as an input to PythonRefInfo. `torch.block_diag`
- expects arguments splatted, but `aten.block_diag` expects only
- one argument that is a list of Tensors.
- """
- return _block_diag_iterable(tensors)
- # CompositeImplicitAutograd - don't register decomp
- def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
- if a.ndim < 3:
- raise RuntimeError(
- f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
- )
- if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
- raise RuntimeError(
- "torch.dsplit attempted to split along dimension 2, "
- + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
- )
- return tensor_split(a, sections, 2)
- @register_decomposition(aten.t.default)
- def t(a: TensorLikeType):
- # TODO: Add sparse support
- # if a.is_sparse:
- # sparse_dim = a.sparse_dim()
- # dense_dim = a.dense_dim()
- # if not (sparse_dim <= 2 and dense_dim == 0):
- # raise RuntimeError(
- # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
- # f"{dense_dim} dense dimensions"
- # )
- if a.ndim > 2:
- raise RuntimeError(
- f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
- )
- return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)
- # CompositeImplicitAutograd - don't register decomp
- def T(a: TensorLikeType) -> TensorLikeType:
- # n != 2 && n != 0 is deprecated in regular PyTorch.
- torch._check(
- a.ndim in (0, 2),
- lambda: (
- "The use of `x.T` on tensors of dimension other than 0 or 2 "
- "to reverse their shape is not supported."
- ),
- )
- return a.t()
- @register_decomposition(aten.alias)
- def alias(a: TensorLikeType) -> TensorLikeType:
- return prims.view_of(a)
- @register_decomposition(aten.transpose)
- def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
- _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc]
- if a.ndim <= 1 or dim0 == dim1:
- return aten.alias.default(a)
- _permutation = list(range(0, a.ndim))
- _permutation[_dim0] = _dim1
- _permutation[_dim1] = _dim0
- return torch.permute(a, _permutation)
- # Aliases for transpose
- swap_axes = transpose
- @register_decomposition(aten.unfold)
- def unfold(
- self: TensorLikeType, dimension: int, size: int, step: int
- ) -> TensorLikeType:
- shape, strides = _get_unfold_shape_stride(
- self.shape, self.stride(), dimension, size, step
- )
- return self.as_strided(shape, strides)
- @register_decomposition(aten.unfold_copy)
- @out_wrapper()
- def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
- return self.unfold(dimension, size, step).clone(
- memory_format=torch.contiguous_format
- )
- def _cumsumprod_common(
- func,
- init,
- a: TensorLikeType,
- dim: int,
- *,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # We implement all the kwargs of a reduction. ATen just handles dtype
- # nb. This decomposition may not be as efficient as a backend-specific implementation
- ndim = a.ndim
- dim = utils.canonicalize_dim(ndim, dim)
- if ndim == 0:
- return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out)
- a = a.unsqueeze(dim + 1)
- rg = torch.arange(a.shape[dim], device=a.device)
- mask = rg.unsqueeze(1) <= rg
- for _ in range(ndim - dim - 1):
- mask = mask.unsqueeze(-1)
- masked_a = torch.where(mask, a, init)
- return func(masked_a, dim=dim, dtype=dtype, out=out)
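- # Editor's note: a sketch (not part of this module) of the masking trick above
- # for cumsum: entry i of the result reduces the inputs at positions j <= i,
- # selected by a triangular mask. The helper name is hypothetical.
- def _example_masked_cumsum():
-     import torch
-     a = torch.tensor([1.0, 2.0, 3.0])
-     rg = torch.arange(3)
-     mask = rg.unsqueeze(1) <= rg  # mask[j, i] is True when j <= i
-     masked = torch.where(mask, a.unsqueeze(1), 0.0)
-     assert torch.equal(masked.sum(dim=0), torch.cumsum(a, dim=0))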
- @register_decomposition(aten.cumsum)
- def cumsum(
- a: TensorLikeType,
- dim: int,
- *,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out)
- @register_decomposition(aten.cumprod)
- def cumprod(
- a: TensorLikeType,
- dim: int,
- *,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out)
- # Note: although squeeze is documented as having the out= kwarg it doesn't
- @register_decomposition(aten.unsqueeze)
- def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
- # Note that unsqueeze canonicalizes with rank + 1 because it allows
- # a new innermost dimension to be specified
- ndim = a.ndim + 1
- dim = utils.canonicalize_dim(ndim, dim)
- return prims.expand_dims(a, (dim,), ndim=ndim)
- # NOTE: shape is a vararg because Tensor.view can be called as
- # Tensor.view(a, b, c) or Tensor.view((a, b, c)). The function call torch.view
- # doesn't support unpacked shapes
- # TODO: Turn this into a decomposition (currently fails on reshape meta tests)
- @register_decomposition(aten.view.default)
- def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
- return _reshape_view_helper(a, *shape, allow_copy=False)
- # CompositeImplicitAutograd - don't register decomp
- def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
- return self.view(other.size())
- # CompositeImplicitAutograd - don't register decomp
- def ravel(a: TensorLikeType) -> TensorLikeType:
- return reshape(a, (-1,))
- # CompositeImplicitAutograd - don't register decomp
- # missing ref impl. for aten.gather
- @out_wrapper()
- def take_along_dim(
- a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
- ) -> torch.Tensor:
- torch._check(
- a.ndim == indices.ndim,
- lambda: (
- "torch.take_along_dim(): input and indices should have the same "
- f"number of dimensions, but got {a.ndim} dimensions for input, and "
- f"{indices.ndim} dimensions for indices"
- ),
- )
- torch._check(
- utils.is_integer_dtype(indices.dtype),
- lambda: (
- "torch.take_along_dim(): dtype of indices should be int but got "
- f"{indices.dtype} instead"
- ),
- )
- if dim is None:
- return torch.gather(a.view(-1), 0, indices.view(-1))
- else:
- self_sizes = list(a.shape)
- self_sizes[dim] = indices.size(dim)
- broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size())
- indices_broadcast = broadcast_to(indices, broadcast_shape)
- indices_sizes = list(indices.shape)
- indices_sizes[dim] = a.size(dim)
- broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size())
- self_broadcast = broadcast_to(a, broadcast_shape)
- return torch.gather(self_broadcast, dim, indices_broadcast)
- @out_wrapper()
- def empty(
- *shape,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- requires_grad: bool = False,
- pin_memory: bool = False,
- memory_format: torch.memory_format = torch.contiguous_format,
- ) -> TensorLikeType:
- torch._check(
- memory_format != torch.preserve_format,
- lambda: "torch.empty: the Preserve memory format is not supported",
- )
- shape = utils.extract_shape_from_varargs(shape)
- if memory_format == torch.contiguous_format:
- strides = utils.make_contiguous_strides_for(shape)
- elif memory_format == torch.channels_last_3d:
- strides = utils.make_channels_last_3d_strides_for(shape)
- else: # memory_format == torch.channels_last
- torch._check(
- memory_format == torch.channels_last,
- lambda: f"torch.empty: received an unknown memory format {memory_format}!",
- )
- strides = utils.make_channels_last_2d_strides_for(shape)
- return torch.empty_strided(
- shape,
- strides,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @out_wrapper()
- def empty_permuted(
- shape,
- physical_layout,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- requires_grad: bool = False,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- return prims.empty_permuted(
- shape,
- physical_layout,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_empty)
- @out_wrapper()
- def new_empty(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.empty(
- size,
- dtype=dtype,
- device=device,
- pin_memory=pin_memory,
- layout=layout,
- )
- @register_decomposition(aten.new_empty_strided)
- @out_wrapper()
- def new_empty_strided(
- a: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.Tensor.new_empty_strided
- """
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.empty_strided(
- size,
- stride,
- dtype=dtype,
- device=device,
- pin_memory=pin_memory,
- layout=layout,
- )
- @register_decomposition(aten.zeros.default)
- @out_wrapper()
- def zeros(
- *size,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- size = utils.extract_shape_from_varargs(size)
- if dtype is None:
- dtype = torch.get_default_dtype()
- return torch.full(
- size,
- False if dtype == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_zeros)
- @out_wrapper()
- def new_zeros(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- False if dtype == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.ones.default)
- @out_wrapper()
- def ones(
- *size,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- size = utils.extract_shape_from_varargs(size)
- if dtype is None:
- dtype = torch.get_default_dtype()
- return torch.full(
- size,
- True if dtype == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_ones)
- @out_wrapper()
- def new_ones(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- True if dtype == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_full)
- @out_wrapper()
- def new_full(
- a: TensorLikeType,
- size: ShapeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- fill_value,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- )
- @register_decomposition(aten.empty_like)
- @out_wrapper()
- def empty_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceLikeType] = None,
- layout: Optional[torch.layout] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- if memory_format != torch.preserve_format:
- return torch.empty(
- a.shape,
- dtype=dtype,
- layout=layout,
- device=device,
- requires_grad=requires_grad,
- pin_memory=pin_memory,
- memory_format=memory_format,
- )
- # memory_format == torch.preserve_format
- logical_to_physical_perm = (
- utils.compute_elementwise_output_logical_to_physical_perm(a)
- )
- # for a contiguous input this is the identity perm [0, 1, ..., ndim - 1]
- return torch.empty_permuted(
- a.shape,
- logical_to_physical_perm,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition([aten.arange.start_step, aten.arange.start_out])
- @out_wrapper()
- def arange(
- start: NumberType = 0,
- end: Optional[NumberType] = None,
- step: NumberType = 1,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- device = torch.device(utils.device_or_default(device))
- assert not isinstance(start, complex)
- assert not isinstance(end, complex)
- assert not isinstance(step, complex)
- # Case: torch.arange(5)
- if end is None:
- end = start
- start = 0
- torch._check(step != 0, lambda: "step must be nonzero")
- if step > 0:
- torch._check(
- end >= start,
- lambda: "upper bound and lower bound inconsistent with step sign",
- )
- elif step < 0:
- torch._check(
- end <= start,
- lambda: "upper bound and lower bound inconsistent with step sign",
- )
- def is_finite(x):
- return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
- torch._check(
- is_finite(start) and is_finite(end),
- lambda: f"unsupported range: {start} -> {end}",
- )
- torch._check(
- is_finite(step),
- lambda: f"step must be finite but got {step}",
- )
- args = (start, end, step)
- integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
- if dtype is None:
- dtype = torch.int64 if integer_args else torch.get_default_dtype()
- is_integer = utils.is_integer_dtype(dtype)
- if is_integer:
- xstart = sym_int(start)
- xend = sym_int(end)
- xstep = sym_int(step)
- # For int64 we truncate arguments to int before calculating length, but
- # other integral dtypes we don't. Weird... but needed to match ATen shapes.
- if dtype == torch.int64:
- # Uses floordiv to avoid ceil in inductor.
- sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
- length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]
- else:
- length = math.ceil((end - start) / step)
- if is_integer:
- return prims.iota(
- length,
- start=xstart, # type: ignore[possibly-undefined]
- step=xstep, # type: ignore[possibly-undefined]
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- index = prims.iota(
- length,
- start=0,
- step=1,
- dtype=torch.int64,
- device=device,
- requires_grad=False,
- )
- computation_dtype = (
- torch.long if integer_args else utils.get_acc_type(dtype, device)
- )
- index = _maybe_convert_to_dtype(index, computation_dtype)
- result = start + step * index
- result = _maybe_convert_to_dtype(result, dtype)
- if requires_grad:
- result.requires_grad_(True)
- return result
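- # Editor's note: a worked example (not part of this module) of the int64
- # length formula above; (end - start + step - sgn) // step agrees with
- # Python's range length for integer arguments. The helper name is hypothetical.
- def _example_arange_length():
-     start, end, step = 3, 10, 2
-     sgn = (step > 0) - (step < 0)
-     length = (end - start + step - sgn) // step  # (7 + 2 - 1) // 2 == 4
-     assert length == len(range(start, end, step)) == 4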
- @register_decomposition(aten.lerp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("start", "end", "weight"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
- inputs = [start, end]
- if isinstance(weight, Number):
- weight = start.new_full((), weight) # type: ignore[arg-type]
- else:
- inputs.append(weight)
- assert isinstance(weight, Tensor) # mypy
- # We implement it this way for numerical stability. We assume (in the stability optimisation)
- # that 0 <= weight <= 1. We take the abs to deal with complex numbers
- # We want to perform operations near zero, which is where floating points are most precise
- # thus, we perform the following optimisation:
- # If weight.abs() >= 0.5:
- # return (1 - weight) * (start - end) + end
- mask = weight.abs() >= 0.5
- coeff = torch.where(mask, weight - 1, weight)
- base = torch.where(mask, end, start)
- output = coeff * (end - start) + base
- # make sure the decomposition output's stride is same as non-decomposition path.
- stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
- if output.stride() != stride:
- output = prims.copy_strided(output, stride)
- return handle_noncontiguous_outputs(inputs, output)
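- # Editor's note: an illustrative check (not part of this module) that both
- # branches above compute lerp(start, end, w) = start + w * (end - start),
- # just anchored at whichever endpoint is closer. The helper name is hypothetical.
- def _example_lerp_branches():
-     import torch
-     start, end = torch.tensor(1.0), torch.tensor(5.0)
-     for w in (0.25, 0.75):  # below / above the 0.5 switch point
-         weight = torch.tensor(w)
-         mask = weight.abs() >= 0.5
-         coeff = torch.where(mask, weight - 1, weight)
-         base = torch.where(mask, end, start)
-         assert torch.allclose(coeff * (end - start) + base, torch.lerp(start, end, weight))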
- @register_decomposition(aten.linspace)
- @out_wrapper()
- def linspace(
- start: Union[NumberType, TensorLikeType],
- end: Union[NumberType, TensorLikeType],
- steps: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceLikeType] = None,
- layout: torch.layout = torch.strided,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- if isinstance(start, TensorLikeType):
- torch._check(
- start.dim() == 0,
- lambda: "linspace only supports 0-dimensional start and end tensors",
- )
- start = _maybe_convert_to_dtype(start, torch.float64)
- if isinstance(end, TensorLikeType):
- torch._check(
- end.dim() == 0,
- lambda: "linspace only supports 0-dimensional start and end tensors",
- )
- end = _maybe_convert_to_dtype(end, torch.float64)
- if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
- default_complex_dtype = utils.corresponding_complex_dtype(
- torch.get_default_dtype()
- )
- if dtype is None:
- dtype = default_complex_dtype
- else:
- torch._check(
- utils.is_complex_dtype(dtype),
- lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
- )
- else:
- dtype = dtype or torch.get_default_dtype()
- assert isinstance(dtype, torch.dtype)
- # steps does not participate in the computation of the dtype
- torch._check_type(
- isinstance(steps, IntLike),
- lambda: f"received an invalid combination of arguments - got \
- ({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
- )
- assert isinstance(steps, IntLike) # for mypy
- torch._check(steps >= 0, lambda: "number of steps must be non-negative")
- factory_kwargs = {
- "layout": layout,
- "device": device,
- "pin_memory": pin_memory,
- "requires_grad": requires_grad,
- }
- if steps == 0:
- return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
- if steps == 1:
- if isinstance(start, TensorLikeType):
- return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start) # type: ignore[arg-type]
- else:
- return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
- # Perform the arange in int64 because some backends, such as ATen or Triton, do not support arange for all dtypes
- rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type]
- # Small types need to be computed in higher precision as this is, at heart, an associative scan
- dtype_red = (
- torch.int64
- if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
- else dtype
- )
- computation_dtype, _ = utils.reduction_dtypes(
- rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
- )
- cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)
- # We implement torch.lerp without performing rg / (steps - 1) explicitly
- # With this we get out[0] == start, out[-1] == end
- step = (end - start) / (steps - 1)
- out = torch.where(
- rg < steps / 2,
- start + step * cast_rg(rg), # type: ignore[arg-type,operator]
- end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator]
- )
- return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value]
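- # Illustrative example (comment only): the symmetric torch.where trick above
- # guarantees the endpoints are hit exactly, i.e. out[0] == start and out[-1] == end:
- # >>> torch.linspace(0, 1, 5)
- # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])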
- @register_decomposition(aten.logspace)
- @out_wrapper()
- def logspace(
- start: Union[NumberType, TensorLikeType],
- end: Union[NumberType, TensorLikeType],
- steps: NumberType,
- base: NumberType = 10,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceLikeType] = None,
- layout: torch.layout = torch.strided,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- if dtype is None:
- dtype = torch.get_default_dtype()
- # NB: NumPy doesn't have this cast
- if prims.utils.is_integer_dtype(dtype):
- if isinstance(start, FloatLike):
- start = sym_int(start)
- elif isinstance(start, TensorLikeType):
- torch._check(
- start.dim() == 0,
- lambda: "logspace only supports 0-dimensional start and end tensors",
- )
- start = _maybe_convert_to_dtype(start, dtype)
- if isinstance(end, FloatLike):
- end = sym_int(end)
- elif isinstance(end, TensorLikeType):
- torch._check(
- end.dim() == 0,
- lambda: "logspace only supports 0-dimensional start and end tensors",
- )
- end = _maybe_convert_to_dtype(end, dtype)
- if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
- default_complex_dtype = utils.corresponding_complex_dtype(
- torch.get_default_dtype()
- )
- dtype = default_complex_dtype
- _dtype = None # torch.linspace will update the correct dtype
- else:
- _dtype = torch.float64
- assert not isinstance(base, complex) # for mypy
- if base < 0:
- raise NotImplementedError
- ret = torch.linspace( # type: ignore[misc]
- start, # type: ignore[arg-type]
- end, # type: ignore[arg-type]
- steps, # type: ignore[arg-type]
- dtype=_dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value]
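- # Illustrative note (comment only): logspace is base ** linspace(start, end, steps),
- # e.g. torch.logspace(0, 2, 3) computes 10 ** [0., 1., 2.] == [1., 10., 100.]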
- @overload
- def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
- pass
- @overload
- def meshgrid(*tensors: TensorLikeType, indexing: str):
- pass
- @register_decomposition(aten.meshgrid)
- def meshgrid(
- *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
- indexing: str,
- ) -> List[TensorLikeType]:
- # This ref simultaneously handles two overloads (see stubs above)
- # The `indexing` argument is currently optional for torch.meshgrid, but we
- # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
- if isinstance(tensors[0], (list, tuple)):
- assert len(tensors) == 1
- tensors = tuple(tensors[0])
- torch._check(
- py_all(isinstance(a, TensorLike) for a in tensors),
- lambda: "meshgrid expects its inputs to be tensors",
- )
- torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
- for i in range(len(tensors) - 1):
- torch._check(
- tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
- lambda: "meshgrid expects all tensors to have the same dtype",
- )
- torch._check(
- tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
- lambda: "meshgrid expects all tensors to have the same device",
- )
- swap_first_and_second_tensors = False
- if indexing == "xy":
- swap_first_and_second_tensors = len(tensors) >= 2
- if swap_first_and_second_tensors:
- tensors = (tensors[1], tensors[0], *tensors[2:])
- else:
- torch._check(
- indexing == "ij",
- lambda: (
- 'torch.meshgrid: indexing must be one of "xy" or "ij", '
- f"but received: {indexing}"
- ),
- )
- result_shape: List[int] = []
- for t in tensors:
- assert isinstance(t, TensorLike) # mypy
- torch._check(
- t.ndim == 0 or t.ndim == 1,
- lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
- )
- result_shape.append(t.numel())
- grids: List[TensorLikeType] = []
- for i, t in enumerate(tensors):
- assert isinstance(t, TensorLike) # mypy
- if t.ndim == 0:
- t = t.view((1,))
- grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))
- if swap_first_and_second_tensors:
- # Swap outputs if we originally swapped at the beginning
- grids[0], grids[1] = grids[1], grids[0]
- return grids
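- # Illustrative example (comment only): each 1D input is broadcast along its own
- # dimension of the common result shape, e.g. with "ij" indexing:
- # >>> a, b = torch.meshgrid(torch.tensor([1, 2]), torch.tensor([3, 4, 5]), indexing="ij")
- # >>> a.shape, b.shape
- # (torch.Size([2, 3]), torch.Size([2, 3]))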
- # CompositeImplicitAutograd - don't register decomp
- def movedim(
- input: TensorLikeType,
- source: Union[int, DimsSequenceType],
- destination: Union[int, DimsSequenceType],
- ) -> TensorLikeType:
- """
- Reference implementation of torch.movedim
- """
- if type(source) is int:
- source = (source,)
- if type(destination) is int:
- destination = (destination,)
- # Converts to list to produce an error message compatible with core PyTorch,
- # which prints sequences in square brackets.
- torch._check(
- len(source) == len(destination), # type: ignore[arg-type]
- lambda: (
- "movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
- f"({list(source)} dims) should contain the same number " # type: ignore[arg-type]
- f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type]
- ),
- )
- rank = input.ndim
- ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type]
- ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type]
- sss = set(ss)
- dss = set(ds)
- # See above on why this converts to list in error messages.
- torch._check(
- len(ss) == len(sss),
- lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
- )
- torch._check(
- len(ds) == len(dss),
- lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
- )
- m = dict(zip(ds, ss))
- dims = []
- si = 0 # source index
- for di in range(rank):
- # check if the destination index is in the mapping
- s = m.get(di)
- if s is not None:
- # insert source index if found
- dims.append(s)
- else:
- # insert source index sequentially, skipping indices from the mapping
- while si in sss:
- si += 1
- dims.append(si)
- si += 1
- result = torch.permute(input, tuple(dims))
- return result
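- # Illustrative example (comment only): the mapping above places each source dim at
- # its destination and fills the remaining dims in their original order, e.g.:
- # >>> torch.movedim(torch.zeros(2, 3, 4), 0, 2).shape
- # torch.Size([3, 4, 2])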
- # NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
- @register_decomposition(aten.empty_strided)
- @out_wrapper()
- def empty_strided(
- shape: Union[ShapeType, Tuple[ShapeType]],
- strides: StrideType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceLikeType] = None,
- layout: torch.layout = torch.strided,
- requires_grad: bool = False,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- # Layout == strided, pin_memory is False
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- shape = utils.extract_shape_from_varargs(shape)
- dtype = torch.get_default_dtype() if dtype is None else dtype
- device = torch.device("cpu") if device is None else device
- return prims.empty_strided(
- shape,
- strides,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.eye)
- @out_wrapper()
- def eye(
- n: int,
- m: Optional[int] = None,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False, # TODO: unused
- ) -> TensorLikeType:
- """
- Reference implementation of torch.eye
- """
- if m is None:
- m = n
- torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
- torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
- range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
- range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
- cond = range_n.unsqueeze(-1) == range_m
- if dtype is torch.bool:
- return cond
- else:
- one = torch.ones(
- (1,),
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=False,
- )
- return torch.where(cond, one, 0)
- # TODO: Use requires_grad. All refs taking the requires_grad kwarg must
- # return a leaf tensor.
- # result.requires_grad_(requires_grad)
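- # Illustrative example (comment only): the outer equality of the two aranges puts
- # ones exactly where the row index equals the column index, e.g.:
- # >>> torch.eye(2, 3)
- # tensor([[1., 0., 0.],
- #         [0., 1., 0.]])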
- @register_decomposition([aten.full.default, aten.full.out])
- @out_wrapper()
- def full(
- shape: ShapeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
- device = device if device is not None else torch.device("cpu")
- e = empty(
- shape,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- return torch.fill(e, fill_value) # type: ignore[arg-type]
- def full_like(
- a: TensorLikeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- e = torch.empty_like(
- a,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- return fill(e, fill_value)
- @register_decomposition(aten.zeros_like)
- @out_wrapper()
- def zeros_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- return torch.full_like(
- a,
- False if (dtype or a.dtype) == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- @register_decomposition(aten.ones_like)
- @out_wrapper()
- def ones_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- return torch.full_like(
- a,
- True if (dtype or a.dtype) == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- @register_decomposition(aten.randn.default)
- @out_wrapper()
- def randn(
- *shape,
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceLikeType] = None,
- layout: Optional[torch.layout] = None,
- requires_grad: bool = False,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- utils.check_pin_memory(pin_memory)
- shape_ = utils.extract_shape_from_varargs(shape)
- dtype = utils.dtype_or_default(dtype)
- device = utils.device_or_default(device)
- return prims.normal(
- shape_,
- mean=0.0,
- std=1.0,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- def scalar_tensor(
- a: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[DeviceLikeType] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
- device = device if device is not None else torch.device("cpu")
- return prims.scalar_tensor(a, dtype=dtype, device=device)
- #
- # Randomness References
- #
- def _uniform_helper(
- shape: ShapeType,
- low: Union[bool, int, float] = 0.0,
- high: Union[bool, int, float] = 1.0,
- *,
- dtype: torch.dtype,
- device: DeviceLikeType,
- ) -> TensorLikeType:
- utils.validate_shape(shape)
- assert isinstance(low, Number)
- assert isinstance(high, Number)
- low = sym_float(low)
- high = sym_float(high)
- assert isinstance(dtype, torch.dtype)
- device = utils.canonicalize_device(device)
- return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)
- @register_decomposition(aten.masked_fill)
- @out_wrapper()
- def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
- python_type = utils.dtype_to_type(a.dtype)
- if isinstance(value, Number):
- value_type = type(value)
- else:
- # NOTE: Could not use value = item(value) as it resulted in
- # RuntimeError: Cannot cast FakeTensor(cpu) to number
- value_ndim = value.ndim
- torch._check(
- value_ndim == 0,
- lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
- )
- # `masked_fill` allows a CPU scalar `value` to be moved to a cuda/xpu/privateuse1 tensor's device, but not the reverse.
- is_cpu_scalar = (
- a.device.type in ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
- and value.device.type == "cpu"
- )
- torch._check(
- is_cpu_scalar or value.device == a.device,
- lambda: "Expected `value` to be on same device as `a`",
- )
- value_type = utils.dtype_to_type(value.dtype)
- if value_type is complex:
- # Only downcasting from complex to a lower type is not allowed.
- # We allow casting `value` to a lower type in other cases,
- # e.g. float -> int.
- # Ref: https://github.com/pytorch/pytorch/issues/79195
- torch._check(
- utils.is_weakly_lesser_type(value_type, python_type),
- lambda: f"could not convert to type {python_type} without overflow",
- )
- # Since `where` allows type-promotion,
- # cast value to correct type before passing to `where`
- value = _maybe_convert_to_dtype(value, a.dtype)
- r = torch.where(mask, value, a) # type: ignore[arg-type]
- # aten.masked_fill always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return r.contiguous()
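- # Illustrative example (comment only): masked_fill is a broadcasted where, e.g.:
- # >>> torch.tensor([1.0, 2.0, 3.0]).masked_fill(torch.tensor([True, False, True]), 0.0)
- # tensor([0., 2., 0.])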
- @register_decomposition(aten.masked_fill_)
- def masked_fill_(
- a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
- ) -> TensorLikeType:
- b = torch.masked_fill(a, mask, value) # type: ignore[arg-type]
- a.copy_(b)
- return a
- # CompositeImplicitAutograd - don't register decomp
- def allclose(
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float = 1e-05,
- atol: float = 1e-08,
- equal_nan: bool = False,
- ) -> bool:
- """
- Reference implementation of torch.allclose
- """
- _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
- return bool(
- torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
- )
- def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
- utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
- utils.check_same_dtype(a, b)
- # Shape check
- if a.ndim != b.ndim:
- return False
- for x, y in zip(a.shape, b.shape):
- if x != y:
- return False
- # Short-circuits if there are no elements to validate
- if a.numel() == 0:
- return True
- return item(all(eq(a, b))) # type: ignore[return-value]
- @register_decomposition(aten.norm)
- @out_wrapper(exact_dtype=True)
- def norm(
- input: TensorLikeType,
- p: Optional[Union[float, str]] = "fro",
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- # In these cases we compute the "Frobenius norm"
- if (
- p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
- ) or p is None:
- p = 2
- if isinstance(dim, Dim):
- dim = [dim]
- if isinstance(p, str):
- # Here we either call the nuclear norm, or we call matrix_norm with some arguments
- # that will throw an error
- if dim is None:
- dim = tuple(range(input.ndim))
- return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
- else:
- return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)
- @register_decomposition(aten.trace)
- @out_wrapper()
- def trace(self: TensorLikeType) -> TensorLikeType:
- torch._check(
- self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
- )
- return torch.sum(torch.diag(self, 0))
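- # Illustrative example (comment only): trace reduces to summing the main diagonal, e.g.:
- # >>> torch.trace(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
- # tensor(5.)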
- def _make_r_binary_op(base_op):
- def rop(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> TensorLikeType:
- return base_op(b, a)
- return rop
- rtruediv = _make_r_binary_op(true_divide)
- rfloordiv = _make_r_binary_op(floor_divide)
- rpow = _make_r_binary_op(pow)
- @register_decomposition(aten.triu)
- @out_wrapper()
- def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- torch._check(
- a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
- )
- h, w = a.shape[-2:]
- mask = (
- torch.arange(w, device=a.device).unsqueeze(-2)
- - torch.arange(h, device=a.device).unsqueeze(-1)
- ) >= diagonal
- # aten.triu always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(mask, a).contiguous()
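- # Illustrative note (comment only): for a 3x3 input the broadcasted difference above
- # is the "column index minus row index" matrix
- #   [[ 0,  1,  2],
- #    [-1,  0,  1],
- #    [-2, -1,  0]]
- # and triu keeps exactly the entries with col - row >= diagonal.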
- @register_decomposition(aten.tril)
- @out_wrapper()
- def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- torch._check(
- a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
- )
- h, w = a.shape[-2:]
- mask = (
- torch.arange(w, device=a.device).unsqueeze(-2)
- - torch.arange(h, device=a.device).unsqueeze(-1)
- ) <= diagonal
- # aten.tril always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(mask, a).contiguous()
- # This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
- # The components of the matrix that belong to the lower triangle with offset
- # form a pentagon that can be broken down into a top trapezoid and a bottom
- # rectangle. For the implementation of tril_indices, we need the sizes of
- # both of these, as well as the length of the top side of the trapezoid.
- def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
- if row == 0 or col == 0:
- return 0, 0, 0
- m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
- m_last_row = max(0, min(col, row + offset))
- n_row_all = max(0, min(row, row + offset))
- n_row_trapezoid = m_last_row - m_first_row + 1
- # Number of elements in top trapezoid
- trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
- # Number of elements in bottom rectangle
- diff_row = n_row_all - n_row_trapezoid
- rectangle_size = max(0, diff_row * col)
- return trapezoid_size, rectangle_size, m_first_row
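- # Worked example (comment only): for row=4, col=4, offset=0 we get m_first_row=1,
- # m_last_row=4, n_row_trapezoid=4, so trapezoid_size = (1 + 4) * 4 // 2 == 10 and
- # rectangle_size == 0, matching the 4 + 3 + 2 + 1 = 10 elements of the lower triangle.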
- def _trilu_checks(
- name: str,
- row: int,
- col: int,
- dtype: torch.dtype,
- layout: torch.layout,
- pin_memory: bool,
- ):
- torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
- torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
- torch._check(
- dtype in (torch.int32, torch.int64),
- lambda: f"\"{name}\" not implemented for '{dtype}'",
- )
- # This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
- @register_decomposition(aten.tril_indices)
- @out_wrapper()
- def tril_indices(
- row: int,
- col: int,
- offset: int = 0,
- *,
- dtype: torch.dtype = torch.long,
- layout: torch.layout = torch.strided,
- device: DeviceLikeType = "cpu",
- pin_memory: bool = False,
- ) -> TensorLikeType:
- _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)
- trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
- row_offset = max(0, -offset)
- arange_kw = partial(
- torch.arange, layout=layout, device=device, pin_memory=pin_memory
- )
- # first we do the indices for top trapezoid
- xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
- b = m_first_row - 0.5
- row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
- col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
- row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
- col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
- # then bottom rectangle
- xs2 = arange_kw(0, rectangle_size, dtype=dtype)
- row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
- col_inds2 = xs2 % col
- return torch.stack(
- (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
- )
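- # Illustrative example (comment only): the trapezoid row indices come from inverting
- # the quadratic cumulative-count formula above, e.g.:
- # >>> torch.tril_indices(3, 3)
- # tensor([[0, 1, 1, 2, 2, 2],
- #         [0, 0, 1, 0, 1, 2]])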
- # Similar to _get_tril_sizes above, but here there is a top trapezoid and
- # a bottom rectangle instead. Note that you can't reduce this to
- # _get_tril_sizes(col, row, -offset) because that would correspond to
- # decomposing into a left trapezoid and right rectangle.
- def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
- if row == 0 or col == 0:
- return 0, 0, 0
- m_first_row = max(0, col - offset) if offset > 0 else col
- # Number of elements in top rectangle
- rectangle_size = max(0, min(row, -offset) * col)
- # Number of elements in bottom trapezoid
- trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
- triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
- trapezoid_size = triu_size - rectangle_size
- return trapezoid_size, rectangle_size, m_first_row
- @register_decomposition(aten.triu_indices)
- @out_wrapper()
- def triu_indices(
- row: int,
- col: int,
- offset: int = 0,
- *,
- dtype: torch.dtype = torch.long,
- layout: torch.layout = torch.strided,
- device: DeviceLikeType = "cpu",
- pin_memory: bool = False,
- ) -> TensorLikeType:
- _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)
- trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
- col_offset = max(0, offset)
- arange_kw = partial(
- torch.arange, layout=layout, device=device, pin_memory=pin_memory
- )
- # indices for top rectangle
- xs2 = arange_kw(0, rectangle_size, dtype=dtype)
- row_inds2 = xs2 // col
- col_inds2 = xs2 % col
- # bottom trapezoid
- xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
- b = -0.5 - m_first_row
- row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
- col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
- row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
- col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
- if col:
- row_inds1 = row_inds1 + (rectangle_size // col)
- col_inds1 = col_inds1 + col_offset
- return torch.stack(
- (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
- )
- @register_decomposition(aten.bucketize)
- @out_wrapper(exact_dtype=True)
- def bucketize(
- a: TensorLikeType,
- boundaries: TensorLikeType,
- *,
- out_int32: bool = False,
- right: bool = False,
- ):
- torch._check(
- boundaries.dim() == 1,
- lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
- )
- out_dtype = torch.int32 if out_int32 else torch.int64
- n_boundaries = boundaries.shape[-1]
- if n_boundaries == 0:
- return torch.zeros_like(a)
- # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
- # that each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
- # but each step of the search is done "in parallel" over all elements of `a`.
- # We can't use int32 as indices, so we do all computations in int64 and convert at the end.
- start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
- end = start + n_boundaries
- # Max depth of the binary search.
- # Since we can't break out of the loop at different points for different elements of a,
- # we just do the max number of iterations that binary search requires and add a condition
- # tensor (cond_update below) to stop updating once the search terminates.
- # For the first iteration through the loop we can skip some checks, so it has a separate implementation.
- mid = start + (end - start) // 2
- mid_val = boundaries[mid]
- if right:
- cond_mid = mid_val > a
- else:
- cond_mid = mid_val >= a
- start = torch.where(cond_mid, start, mid + 1)
- if n_boundaries > 1:
- cond_update = torch.ones_like(a, dtype=torch.bool)
- niters = int(math.log2(n_boundaries))
- for _ in range(niters):
- end = torch.where(cond_mid & cond_update, mid, end)
- cond_update = start < end
- # start might end up pointing one past the end; we guard against that
- mid = torch.where(cond_update, start + (end - start) // 2, 0)
- mid_val = boundaries[mid]
- # If right is true, the buckets are closed on the *left*
- # (i.e., we are doing the equivalent of std::upper_bound in C++)
- # Otherwise they are closed on the right (std::lower_bound)
- if right:
- cond_mid = mid_val > a
- else:
- cond_mid = mid_val >= a
- start = torch.where((~cond_mid) & cond_update, mid + 1, start)
- return start.to(dtype=out_dtype)
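- # Illustrative example (comment only): with right=False each value maps to the first
- # boundary that is >= the value (std::lower_bound semantics), e.g.:
- # >>> torch.bucketize(torch.tensor([1, 5, 9]), torch.tensor([2, 4, 6, 8]))
- # tensor([0, 2, 4])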
- @register_decomposition(aten.cauchy)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def cauchy(self, median=0, sigma=1, generator=None):
- assert generator is None
- torch._check(
- not utils.is_complex_dtype(self.dtype)
- and not utils.is_integer_dtype(self.dtype)
- and not utils.is_boolean_dtype(self.dtype),
- lambda: f"Cauchy distribution is a continuous probability distribution. \
- dtype must be a floating point but you specified {self.dtype}",
- )
- torch._check(
- sigma > 0.0,
- lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
- )
- return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))
- @register_decomposition(aten.exponential)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def exponential(self, rate=1, generator=None):
- assert generator is None
- torch._check(
- not utils.is_complex_dtype(self.dtype)
- and not utils.is_integer_dtype(self.dtype)
- and not utils.is_boolean_dtype(self.dtype),
- lambda: f"Exponential distribution is a continuous probability distribution. \
- dtype must be a floating point but you specified {self.dtype}",
- )
- torch._check(
- rate > 0.0,
- lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
- )
- return -1 / rate * torch.log1p(-torch.rand_like(self))
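- # Derivation note (comment only): this is inverse-transform sampling. The exponential
- # CDF is F(x) = 1 - exp(-rate * x), so solving u = F(x) for u ~ U[0, 1) gives
- # x = -log(1 - u) / rate, which is exactly -log1p(-u) / rate as computed above.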
- @register_decomposition(aten.geometric)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def geometric(self, p, generator=None):
- assert generator is None
- # TODO: fix inductor rand_like for integer, bool dtypes
- torch._check(
- not utils.is_complex_dtype(self.dtype)
- and not utils.is_boolean_dtype(self.dtype),
- lambda: f"geometric not implemented for {self.dtype}",
- )
- torch._check(
- 0 < p and p < 1,
- lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
- )
- return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1
- @register_decomposition(aten.log_normal)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def log_normal(self, mean=1, std=2, generator=None):
- assert generator is None
- torch._check(
- not utils.is_complex_dtype(self.dtype)
- and not utils.is_integer_dtype(self.dtype)
- and not utils.is_boolean_dtype(self.dtype),
- lambda: f"log_normal not implemented for {self.dtype}",
- )
- torch._check(
- 0 < std,
- lambda: f"log_normal_ expects std > 0.0, but found std={std}",
- )
- return torch.exp(std * torch.randn_like(self) + mean)
- # TODO: add support for functionalization aten.normal_functional
- # NOTE: the device and dtype will be ignored when size is None
- @register_decomposition(aten.normal)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=(
- "mean",
- "std",
- ),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def normal(
- mean=0,
- std=1,
- size=None,
- *,
- generator=None,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=None,
- ):
- assert layout is None or layout == torch.strided
- if not isinstance(std, TensorLike):
- torch._check(
- std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
- )
- if size is None:
- tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
- torch._check(
- len(tensors) > 0,
- lambda: "normal expects that either mean or std is a tensor, or size is defined",
- )
- torch._check(
- layout is None and pin_memory is None,
- lambda: "Cannot pass layout, or pin_memory without size",
- )
- size = _broadcast_shapes(*(t.shape for t in tensors))
- dtype = tensors[0].dtype
- device = tensors[0].device
- else:
- torch._check(
- not isinstance(mean, TensorLike) and not isinstance(std, TensorLike),
- lambda: "normal expects mean and std to be scalars when size is defined",
- )
- dtype = torch.get_default_dtype() if dtype is None else dtype
- device = torch.device("cpu") if device is None else device
- normal_samples = prims.normal(
- size,
- mean=0.0,
- std=1.0,
- dtype=dtype,
- device=device,
- requires_grad=False,
- generator=generator,
- )
- return std * normal_samples + mean
- @register_decomposition(aten.normal_)
- def normal_(self, mean=0, std=1, *, generator=None):
- return normal(mean, std, self.shape, out=self, generator=generator)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def rad2deg(self: TensorLikeType):
- torch._check(
- not utils.is_complex_dtype(self.dtype),
- lambda: "rad2deg is not supported for complex tensors.",
- )
- M_180_PI = 57.295779513082320876798154814105170332405472466564
- return self * M_180_PI
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def deg2rad(self: TensorLikeType):
- torch._check(
- not utils.is_complex_dtype(self.dtype),
- lambda: "deg2rad is not supported for complex tensors.",
- )
- M_PI_180 = 0.017453292519943295769236907684886127134428718885417
- return self * M_PI_180
- @register_decomposition(aten.count_nonzero)
- @out_wrapper()
- def count_nonzero(self, dim: Optional[DimsType] = None):
- return (self != 0).sum(dim)
- def _dot_check(self, other):
- torch._check(
- self.dim() == 1 and other.dim() == 1,
- lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
- )
- def numel_error():
- return (
- f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
- f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
- )
- torch._check(self.numel() == other.numel(), numel_error)
- @register_decomposition(aten.dot)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "other"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def dot(self, other):
- if self.is_complex():
- if self.is_conj():
- if other.is_conj():
- return torch.dot(self.conj(), other.conj()).conj()
- else:
- return torch.vdot(self.conj(), other)
- elif other.is_conj():
- return torch.vdot(other.conj(), self)
- _dot_check(self, other)
- return (self * other).sum()
- @register_decomposition(aten.vdot)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "other"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def vdot(self, other):
- if not self.is_complex():
- return torch.dot(self, other)
- if self.is_conj():
- if other.is_conj():
- return torch.vdot(other.conj(), self.conj())
- else:
- return torch.dot(self.conj(), other)
- elif other.is_conj():
- return torch.dot(self, other.conj()).conj()
- _dot_check(self, other)
- # The decomposition fails if we use self.conj() here; the reason is unclear,
- # so we use conj_physical() instead.
- return (self.conj_physical() * other).sum()
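- # Illustrative note (comment only): vdot conjugates its first argument, e.g.
- # torch.vdot on a = [1 + 1j] and b = [2 + 0j] computes conj(1 + 1j) * 2 == 2 - 2j.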
- @register_decomposition(aten.select_scatter)
- @out_wrapper()
- def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int):
- dim = utils.canonicalize_dim(x.ndim, dim)
- mask_shape = [1] * x.ndim
- mask_shape[dim] = -1
- if index < 0:
- index = index + x.shape[dim]
- mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index
- src = torch.unsqueeze(src, dim).expand(x.shape)
- return torch.where(mask, src, x)
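- # Illustrative example (comment only): the arange-based mask selects one slice along
- # `dim`, and `where` writes the broadcast `src` into it, e.g.:
- # >>> torch.select_scatter(torch.zeros(2, 3), torch.ones(3), 0, 0)
- # tensor([[1., 1., 1.],
- #         [0., 0., 0.]])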
- # inplace
- abs_ = _make_inplace(abs)
- acos_ = _make_inplace(acos)
- acosh_ = _make_inplace(acosh)
- add_ = _make_inplace(add)
- addcmul_ = _make_inplace(addcmul)
- addcdiv_ = _make_inplace(addcdiv)
- asin_ = _make_inplace(asin)
- asinh_ = _make_inplace(asinh)
- atan_ = _make_inplace(atan)
- atanh_ = _make_inplace(atanh)
- atan2_ = _make_inplace(atan2)
- bitwise_and_ = _make_inplace(bitwise_and)
- bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
- bitwise_not_ = _make_inplace(bitwise_not)
- bitwise_or_ = _make_inplace(bitwise_or)
- bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
- bitwise_xor_ = _make_inplace(bitwise_xor)
- ceil_ = _make_inplace(ceil)
- clamp_ = _make_inplace(clamp)
- clamp_min_ = _make_inplace(clamp_min)
- clamp_max_ = _make_inplace(clamp_max)
- conj_physical_ = _make_inplace(conj_physical)
- copysign_ = _make_inplace(copysign)
- cos_ = _make_inplace(cos)
- cosh_ = _make_inplace(cosh)
- cumsum_ = _make_inplace(cumsum)
- cumprod_ = _make_inplace(cumprod)
- deg2rad_ = _make_inplace(deg2rad)
- digamma_ = _make_inplace(digamma)
- div_ = _make_inplace(div)
- eq_ = _make_inplace(eq)
- erf_ = _make_inplace(erf)
- erfc_ = _make_inplace(erfc)
- erfinv_ = _make_inplace(erfinv)
- exp_ = _make_inplace(exp)
- exp2_ = _make_inplace(exp2)
- expm1_ = _make_inplace(expm1)
- float_power_ = _make_inplace(float_power)
- floor_ = _make_inplace(floor)
- floor_divide_ = _make_inplace(floor_divide)
- fmod_ = _make_inplace(fmod)
- frac_ = _make_inplace(frac)
- gcd_ = _make_inplace(gcd)
- ge_ = _make_inplace(ge)
- gt_ = _make_inplace(gt)
- heaviside_ = _make_inplace(heaviside)
- hypot_ = _make_inplace(hypot)
- igamma_ = _make_inplace(igamma)
- igammac_ = _make_inplace(igammac)
- i0_ = _make_inplace(i0)
- lcm_ = _make_inplace(lcm)
- le_ = _make_inplace(le)
- lerp_ = _make_inplace(lerp)
- lgamma_ = _make_inplace(lgamma)
- log10_ = _make_inplace(log10)
- log1p_ = _make_inplace(log1p)
- log2_ = _make_inplace(log2)
- log_ = _make_inplace(log)
- logical_and_ = _make_inplace(logical_and)
- logical_not_ = _make_inplace(logical_not)
- logical_or_ = _make_inplace(logical_or)
- logical_xor_ = _make_inplace(logical_xor)
- lt_ = _make_inplace(lt)
- mul_ = _make_inplace(mul)
- mvlgamma_ = _make_inplace(mvlgamma)
- nan_to_num_ = _make_inplace(nan_to_num)
- ne_ = _make_inplace(ne)
- neg_ = _make_inplace(neg)
- nextafter_ = _make_inplace(nextafter)
- pow_ = _make_inplace(pow)
- rad2deg_ = _make_inplace(rad2deg)
- reciprocal_ = _make_inplace(reciprocal)
- remainder_ = _make_inplace(remainder)
- rsqrt_ = _make_inplace(rsqrt)
- sgn_ = _make_inplace(sgn)
- sigmoid_ = _make_inplace(sigmoid)
- sign_ = _make_inplace(sign)
- sin_ = _make_inplace(sin)
- sinc_ = _make_inplace(sinc)
- sinh_ = _make_inplace(sinh)
- sqrt_ = _make_inplace(sqrt)
- square_ = _make_inplace(square)
- sub_ = _make_inplace(sub)
- tan_ = _make_inplace(tan)
- tanh_ = _make_inplace(tanh)
- tril_ = _make_inplace(tril)
- triu_ = _make_inplace(triu)
- true_divide_ = _make_inplace(true_divide)
- trunc_ = _make_inplace(trunc)
- xlogy_ = _make_inplace(xlogy)
- cauchy_ = _make_inplace(cauchy)
- exponential_ = _make_inplace(exponential)
- geometric_ = _make_inplace(geometric)
- log_normal_ = _make_inplace(log_normal)
- zero_ = _make_inplace(zero)
- # xref: isStorage in torch/csrc/DynamicTypes.cpp
- def _isStorage(obj):
- return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage))
- # xref: compute_sizes in torch/csrc/utils/tensor_new.cpp
- def _compute_sizes(seq, scalar_type):
- MAX_DIMS = 128
- is_storage = _isStorage(seq)
- sizes = []
- # TODO: this is inaccurate, we actually test PySequence_Check
- while isinstance(seq, (list, tuple)):
- length = len(seq)
- if is_storage:
- length //= scalar_type.itemsize
- sizes.append(length)
- if len(sizes) > MAX_DIMS:
- raise ValueError(f"too many dimensions '{type(seq).__name__}'")
- if length == 0:
- break
- try:
- handle = seq[0]
- except Exception:
- raise ValueError( # noqa: B904
- f"could not determine the shape of object type '{type(seq).__name__}'"
- )
- seq = handle
- return sizes
- # xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp
- def _infer_scalar_type(obj):
- if isinstance(obj, FloatLike):
- return torch.get_default_dtype()
- if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful!
- return torch.int64
- if isinstance(obj, BoolLike):
- return torch.bool
- if isinstance(obj, complex):
- default_dtype = torch.get_default_dtype()
- if default_dtype is torch.float:
- return torch.cfloat
- elif default_dtype is torch.double:
- return torch.cdouble
- elif default_dtype is torch.half:
- return torch.chalf
- else:
- raise RuntimeError("invalid default scalar type for complex")
- if isinstance(obj, torch.Tensor):
- return obj.dtype
- if isinstance(obj, str):
- raise TypeError(f"new(): invalid data type '{type(obj).__name__}'")
- # TODO: this is inaccurate, we actually test PySequence_Check
- if isinstance(obj, (list, tuple)):
- scalarType = None
- length = len(obj)
- # match NumPy semantics, except use default tensor type instead of
- # double.
- if length == 0:
- return torch.get_default_dtype()
- for i in range(length):
- cur_item = obj[i]
- # TODO: test this
- """
- if cur_item is obj:
- raise TypeError("new(): self-referential lists are incompatible")
- """
- item_scalarType = _infer_scalar_type(cur_item) # recurse!
- if scalarType is not None:
- scalarType = torch.promote_types(scalarType, item_scalarType)
- else:
- scalarType = item_scalarType
- if scalarType is torch.cdouble:
- # this won't change (unless we hit undefined, but that will
- # fail later)
- return scalarType
- return scalarType
- raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}")
- # Analogous to recursive_store
- # xref: recursive_store in torch/csrc/utils/tensor_new.cpp
- def _recursive_build(
- scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
- ):
- if isinstance(obj, Tensor) and obj.numel() == 1:
- return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
- elif isinstance(obj, Tensor):
- # It is invalid to call torch.tensor([...]) with a non-scalar tensor in eager mode
- # >>> torch.tensor([torch.randn(2)])
- # ValueError: only one element tensors can be converted to Python scalars
- #
- # But it is possible with a NumPy array
- # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
- # torch.Size([1, 2])
- return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
- elif isinstance(obj, Number):
- return torch.scalar_tensor(obj, dtype=scalarType)
- # seq can be a list of tensors
- seq = obj
- return torch.stack([_recursive_build(scalarType, item) for item in seq])
- # xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
- def _internal_new_from_data(
- options,
- scalar_type,
- device_opt,
- data,
- copy_variables,
- copy_numpy,
- type_inference,
- pin_memory=False,
- ):
- if isinstance(data, torch.Tensor):
- torch._check(
- not pin_memory, lambda: "Can't pin tensor constructed from a variable"
- )
- var = data
- if copy_variables:
- var = var.detach()
- inferred_scalar_type = var.dtype if type_inference else scalar_type
- device = device_opt if device_opt is not None else var.device
- return var.to(
- device=device,
- dtype=inferred_scalar_type,
- non_blocking=False,
- copy=copy_variables,
- )
- # TODO
- if hasattr(data, "__cuda_array_interface__"):
- return NotImplemented
- # TODO: test for numpy input with PyArray_Check
- device = device_opt if device_opt is not None else options["device"]
- inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type
- # NB: Don't need to avoid tracing, as we aren't going to do any manual
- # pointer filling tricks
- if _isStorage(data):
- return NotImplemented
- else:
- if torch.device(device).type == "meta":
- return NotImplemented
- # In the C implementation, we would directly start poking the memory
- # of a freshly allocated CPU tensor. Here, we're going to do an
- # alternate, heinously slow implementation: turn each individual
- # scalar into a tensor, and then stack them all together
- tensor = _recursive_build(inferred_scalar_type, data)
- tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)
- # NB: lift_fresh is not needed, because we built the tensor from scalars
- # guaranteeing a fresh tensor in this case
- return tensor
- # xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
- def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
- # TODO (or not): support names kwarg
- if isinstance(data, torch.Tensor):
- warnings.warn(
- "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
- "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)"
- )
- type_inference = dtype is None
- new_tensor = _internal_new_from_data(
- # device="cpu" because that's what you get with torch.tensor(2) no
- # device by default
- {"device": "cpu"}, # TODO: use torch.get_default_tensor_type
- dtype if dtype is not None else torch.get_default_dtype(),
- device,
- data,
- copy_variables=True,
- copy_numpy=True,
- type_inference=type_inference,
- pin_memory=pin_memory,
- )
- new_tensor.detach_()
- if requires_grad:
- new_tensor.requires_grad_(requires_grad)
- return new_tensor
- # Views
- # We can't model these as above: the `op(a, out=a)` pattern does not work for a
- # view function, since out= does not reshape the input (it just copies the result into it)
- # squeeze_ = _make_inplace(squeeze)
- # t_ = _make_inplace(t)
- # transpose_ = _make_inplace(transpose)
- # unsqueeze_ = _make_inplace(unsqueeze)
- import torch._refs._conversions
- import torch._refs.fft
- import torch._refs.linalg
- import torch._refs.nn.functional
- import torch._refs.special
|