# mypy: allow-untyped-defs
import builtins
import collections
import inspect
import itertools
import math
import operator
import warnings
from collections.abc import Iterable
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch import sym_float, sym_int
from torch._prims_common import (
    BoolLike,
    DeviceLikeType,
    Dim,
    DimsSequenceType,
    DimsType,
    dtype_to_type,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    FloatLike,
    FloatWithoutSymFloat,
    IntLike,
    is_weakly_lesser_type,
    Number,
    NumberType,
    RealNumberType,
    REDUCTION_OUTPUT_TYPE_KIND,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    TensorOrNumberLikeType,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)

# Experimental module containing prototype Python references for existing
# PyTorch operations.

__all__ = [
    #
    # Elementwise Unary References
    #
    "abs",
    "acos",
    "acosh",
    "asinh",
    "asin",
    "atan",
    "atanh",
    "bitwise_not",
    # "cbrt",  # No corresponding torch operation
    "ceil",
    "conj_physical",
    "cos",
    "cosh",
    "count_nonzero",
    "deg2rad",
    "digamma",
    "erf",
    "erfinv",
    "erfc",
    "exp",
    "expm1",
    "exponential",
    "exp2",
    "fill",
    "fill_",
    "floor",
    "frac",
    "geometric",
    "index_add",
    "index_copy",
    "index_copy_",
    "index_select",
    "index_fill",
    "index_fill_",
    "isfinite",
    "isinf",
    "isposinf",
    "isneginf",
    "isnan",
    "isreal",
    "i0",
    "lerp",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "log_normal",
    "log_softmax",
    "mvlgamma",
    "norm",
    "normal",
    "nan_to_num",
    "neg",
    "positive",
    "rad2deg",
    "reciprocal",
    "round",  # TODO: model kwargs
    "sigmoid",
    "sgn",
    "sign",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "softmax",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trace",
    "trunc",
    #
    # Elementwise Binary References
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_left_shift",
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    "clamp_min",
    "clamp_max",
    "copysign",
    "div",
    "eq",
    "float_power",
    "floor_divide",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "heaviside",
    "hypot",
    "igamma",
    "igammac",
    "imag",
    "isclose",
    "lcm",
    # 'ldexp',
    "le",
    "logaddexp",
    "logaddexp2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logsumexp",
    "lt",
    # 'max',  # implement with reductions
    "maximum",
    # 'min',  # implement with reductions
    "minimum",
    "mul",
    "ne",
    "nextafter",
    # 'polar',  # abs, cos, sin
    "pow",
    "real",
    "rpow",
    "remainder",
    "rsub",
    "rtruediv",
    "rfloordiv",
    "sub",
    "true_divide",
    "trunc_divide",
    "xlogy",
    #
    # Elementwise Ternary References
    #
    "addcdiv",
    "addcmul",
    "clamp",
    #
    # Conditional references
    #
    "masked_fill",
    "masked_fill_",
    "where",
    #
    # Data conversion and movement references
    #
    "clone",
    "copy_to",  # TODO: add OpInfo (or implement .to)
    "item",
    "to",
    #
    # Reduction ops
    #
    "all",
    "amax",
    "amin",
    "any",
    "cumsum",
    "cumprod",
    "mean",
    "dot",
    "vdot",
    "std",
    "std_mean",
    "sum",
    "sum_to_size",
    "prod",
    "var",
    "var_mean",
    #
    # Linear algebra ops
    #
    "addr",
    #
    # View & Shape Ops
    #
    "alias",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "as_strided",
    "as_strided_scatter",
    "block_diag",
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "cat",
    "chunk",
    "column_stack",
    "conj",
    "constant_pad_nd",
    "contiguous",
    "diag_embed",
    "diag",
    "diagonal",
    "diagonal_copy",
    "diagonal_scatter",
    "dsplit",
    "dstack",
    "expand",
    "expand_as",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "meshgrid",
    "movedim",
    "narrow",
    "narrow_copy",
    "native_group_norm",
    "native_layer_norm",
    "permute",
    "ravel",
    "repeat",
    "reshape",
    "reshape_as",
    "roll",
    "rot90",
    "rsqrt",
    "stack",
    "swap_axes",  # alias for transpose
    "squeeze",
    "t",
    "T",
    "take_along_dim",
    "tensor_split",
    "transpose",
    "unfold",
    "unfold_copy",
    "unsqueeze",
    "view",
    "view_as",
    "vsplit",
    "vstack",
    "view_as_complex",
    "unflatten",
    "unbind",
    "triu",
    "tril",
    "triu_indices",
    "tril_indices",
    #
    # Tensor Creation
    #
    "arange",
    "cauchy",
    "empty",
    "empty_like",
    "empty_permuted",
    "empty_strided",
    "eye",
    "full",
    "full_like",
    "linspace",
    "logspace",
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_zeros",
    "ones",
    "ones_like",
    "randn",
    "scalar_tensor",
    "zero",
    "zeros",
    "zeros_like",
    #
    # Test-related functions
    #
    "allclose",
    "equal",
    #
    # Statistical operations
    #
    "bucketize",
    #
    # Misc
    #
    "is_complex",
    "renorm",
    "stft",
    "istft",
]

Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
aten = torch._ops.ops.aten

# Note that the docstrings for the public methods from this file are in
# torch/_torch_docs.py


def is_noncontiguous_supported(device):
    if device is not None and device.type == "hpu":
        return False
    return True


def handle_noncontiguous_outputs(input_tlist, output):
    device = None
    from torch._subclasses.fake_tensor import FakeTensor

    for t in input_tlist:
        if isinstance(t, FakeTensor):
            device = t.fake_device
            break

    if not is_noncontiguous_supported(device):
        output = output.contiguous()
    return output


def _broadcast_shapes(*_shapes):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x
        for x in filter(lambda x: x is not None, _shapes)
    )

    # Short-circuits on no input
    if len(shapes) == 0:
        return None

    # Type checking
    # TODO: make common validations available as utils
    for shape in shapes:
        assert isinstance(shape, Sequence)

    # Computes common shape
    common_shape = [
        1,
    ] * reduce(max, (len(shape) for shape in shapes))
    for arg_idx, shape in enumerate(shapes):
        for idx in range(-1, -1 - len(shape), -1):
            if guard_size_oblivious(common_shape[idx] == 1):
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common_shape[idx] = shape[idx]
            elif guard_size_oblivious(shape[idx] != 1):
                if common_shape[idx] != shape[idx]:
                    raise RuntimeError(
                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
                        f"should be broadcastable to {common_shape}"
                    )

    return common_shape
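
# Example (illustrative) of the broadcast rule implemented above: shapes are
# right-aligned and size-1 dimensions are stretched to match.
# >>> _broadcast_shapes((2, 1, 4), (3, 1))
# [2, 3, 4]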

def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Computes common shape
    common_shape = _broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            if not utils.same_shape(x.shape, common_shape):
                return x.expand(common_shape)

            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)

# Utilities should come BEFORE this import
from torch._decomp import register_decomposition

#
# Elementwise unary references
#

infer_aten_op = object()


# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    extra_meta=None,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper()
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if extra_meta is not None:
                extra_meta(a)

            output = prim(a)
            return handle_noncontiguous_outputs([a], output)

        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner

def _make_alias(fn, name):
    """
    This function defines an alias of another function and sets its __name__ attribute.
    It also sets its __module__ attribute to the module of the caller.
    Note that when naively doing `alias = fn`, we have that `alias.__name__ == "fn"` and
    `alias.__module__ == fn.__module__`.
    """

    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)

    _fn.__name__ = name
    _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"]  # type: ignore[union-attr]
    return _fn
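
# Example (illustrative): `mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")`
# below yields a wrapper with `mvlgamma.__name__ == "mvlgamma"` and `__module__` set to
# this module, while calls simply forward to torch.special.multigammaln.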

def _make_inplace(fn):
    """
    Given a function with an out variant (i.e. using `out_wrapper()`), it returns its in-place variant.
    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)

    inplace_name = f"{fn.__name__}_"
    _fn.__name__ = inplace_name
    _fn = register_decomposition(getattr(aten, inplace_name))(_fn)

    # We access the __all__ attribute of the module where fn is defined
    # There may be a cleaner way of doing this...
    from inspect import getmodule

    _all = getmodule(fn).__all__  # type: ignore[union-attr]
    if inplace_name not in _all:
        _all.append(inplace_name)
    return _fn
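
# Example (illustrative, using a hypothetical ref `foo` with an out= variant):
# `foo_ = _make_inplace(foo)` returns a function equivalent to
# `lambda a, *args, **kwargs: foo(a, *args, out=a, **kwargs)`, renames it to "foo_",
# registers it as the decomposition for `aten.foo_`, and appends "foo_" to the
# defining module's __all__ if it is not already present.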

@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
def abs(a):
    return prims.abs(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    return prims.acos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    return prims.acosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    return prims.asin(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asinh(a):
    return prims.asinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    return prims.atan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atanh(a):
    return prims.atanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    return prims.bitwise_not(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def ceil(a):
    return prims.ceil(a)


@register_decomposition(aten.is_complex)
def is_complex(input: TensorLikeType):
    return utils.is_complex_dtype(input.dtype)


@register_decomposition(aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    if not utils.is_complex_dtype(input.dtype):
        return input
    return prims.conj_physical(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    return prims.cos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    return prims.cosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    return prims.digamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    return prims.erf(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    return prims.erf_inv(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    return prims.erfc(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    return prims.exp(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    return prims.expm1(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    return prims.exp2(a)


# Fill has its own implementation because it has a value parameter
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a,"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)

    return prims.fill(a, value)


def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    r = prims.fill(a, value)
    prims.copy_to(a, r)
    return a


@register_decomposition(aten.zero)
@out_wrapper()
def zero(input: TensorLikeType) -> TensorLikeType:
    return torch.zeros_like(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def floor(a):
    return prims.floor(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def frac(x: TensorLikeType) -> TensorLikeType:
    trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
    return torch.sub(x, trunc_x)

# imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    torch._check(
        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
    )
    return prims.imag(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        return prims.isfinite(a)

    return ones_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
    if utils.is_float_dtype(a.dtype):
        return torch.abs(a) == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("-inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, a)


# alias
mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.imag(a) == 0
    return torch.ones_like(a, dtype=torch.bool)


# TODO: if this is special maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
)
def i0(a):
    return prims.bessel_i0(a)

@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    return prims.lgamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    return prims.log(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    return prims.log1p(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    return prims.log2(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    return prims.log10(a)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]

@register_decomposition(aten.logsumexp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logsumexp(
    self: TensorLikeType, dim: DimsType, keepdim: bool = False
) -> TensorLikeType:
    if not isinstance(dim, Iterable):
        dim = (dim,)
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
    maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
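
# Note (illustrative): this is the usual max-subtraction form of log-sum-exp,
#   logsumexp(x) = m + log(sum(exp(x - m))), with m = amax(x),
# which keeps exp's argument non-positive; infinite maxes are replaced by 0 first
# so that `self - maxes` does not produce NaN for +/-inf entries.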

@register_decomposition(aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return a.clone()

    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = torch.finfo(a.dtype).max

    if neginf is None:
        neginf = torch.finfo(a.dtype).min

    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
    return result
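
# Example (illustrative): for a float32 input the defaults replace NaN with 0.0 and
# +/-inf with the dtype's finite extremes, roughly:
# >>> torch.nan_to_num(torch.tensor([float("nan"), float("inf"), float("-inf")]))
# tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38])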

def _neg_meta(a: TensorLikeType):
    torch._check(
        a.dtype is not torch.bool,
        lambda: (
            "Negation, the `-` operator, on a bool tensor is not supported. "
            "If you are trying to invert a mask, use the `~` or `logical_not()` "
            "operator instead."
        ),
    )


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    return prims.neg(a)


# positive does not use _make_elementwise_unary_reference because it does not support out
# CompositeImplicitAutograd - don't register decomp
def positive(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if a.dtype is torch.bool:
        msg = "positive does not support bool tensors."
        raise RuntimeError(msg)
    return a


# real does not use _make_elementwise_unary_reference because it does not support out
def real(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if utils.is_complex_dtype(a.dtype):
        return prims.real(a)
    return a


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    return prims.reciprocal(a)


@register_decomposition(aten.round)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
    if decimals == 0:
        return prims.round(a)
    else:
        ten_pow = 10**decimals
        ten_neg_pow = 10 ** (-decimals)
        return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
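
# Example (illustrative): with decimals=2, round(1.2345) is computed as
# round(1.2345 * 100) * 0.01 == 123.0 * 0.01 == 1.23 (like torch.round,
# halfway cases round to even).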

@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
    return prims.rsqrt(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    return true_divide(1, add(1, exp(neg(a))))


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sgn(a):
    if utils.is_complex_dtype(a.dtype):
        a_abs = a.abs()
        return torch.where(a_abs == 0, 0, a / a_abs)
    else:
        return a.sign()


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sign(a):
    return prims.sign(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def signbit(a):
    return prims.signbit(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    return prims.sin(a)


# Autograd note: This will give the right first derivative at zero (by chance),
# but not the right second derivative
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinc(a):
    a = math.pi * a
    return torch.where(a == 0, 1, torch.sin(a) / a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    return prims.sinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    return prims.sqrt(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd,
)
def square(a: TensorLikeType) -> TensorLikeType:
    return mul(a, a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    return prims.tan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
    return prims.trunc(a)

# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
    input_dtype = self.dtype
    torch._check(
        utils.is_float_dtype(input_dtype),
        lambda: f"view_as_complex is only supported for floating point "
        f"tensors, but got a tensor of scalar type: {input_dtype}",
    )
    sizes = self.size()
    torch._check(
        len(sizes) != 0,
        lambda: "Input tensor must have one or more dimensions",
    )
    torch._check(
        sizes[-1] == 2,
        lambda: "Tensor must have a last dimension of size 2",
    )

    old_strides = self.stride()
    torch._check(
        old_strides[-1] == 1,
        lambda: "Tensor must have a last dimension with stride 1",
    )
    dims = old_strides[:-1]
    torch._check(
        py_all(stride % 2 == 0 for stride in dims),
        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
    )
    torch._check(
        self.storage_offset() % 2 == 0,
        lambda: "Tensor must have a storage_offset divisible by 2",
    )
    return prims.view_element_type(
        self, utils.corresponding_complex_dtype(input_dtype)
    ).squeeze(-1)
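
# Example (illustrative): a contiguous float32 tensor of shape (3, 2) passes the
# checks above (last dim of size 2 with stride 1, remaining strides even, even
# storage offset) and is reinterpreted as a complex64 tensor of shape (3,), pairing
# each row's two floats as real and imaginary parts.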

def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
    should_register_decomposition=True,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        @wraps(prim)
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            torch._check_value(
                supports_lhs_python_scalar or not isinstance(a, Number),
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
            )
            torch._check_value(
                supports_rhs_python_scalar or not isinstance(b, Number),
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
            )
            torch._check_value(
                supports_two_python_scalars
                or not (isinstance(a, Number) and isinstance(b, Number)),
                lambda: f"{name}: Received two Number inputs to an elementwise binary operation!",
            )
            a, b = _maybe_broadcast(a, b)
            output = prim(a, b)
            return handle_noncontiguous_outputs([a, b], output)

        if has_out:
            _ref = out_wrapper()(_ref)

        _ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if aten_op is not None and should_register_decomposition:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner
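
# Note (illustrative): `inner` above wires a binary prim into a full reference: it
# validates which sides may be Python scalars, broadcasts the operands with
# _maybe_broadcast, applies elementwise type promotion, and optionally adds an
# out= variant and registers the result as a decomposition for the matching aten op.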

# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if python_type != bool and not utils.is_weakly_lesser_type(
            type(alpha), python_type
        ):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        if isinstance(b, TensorLike):
            b = prims.mul(b, alpha)
        else:
            b = b * alpha

    output = prims.add(a, b)
    return handle_noncontiguous_outputs([a, b], output)
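
# Example (illustrative): `add(a, b, alpha=2)` computes `a + 2 * b`; alpha must be
# weakly castable to the result's Python type, so a float alpha combined with
# integer tensors raises a ValueError.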
  902. @_make_elementwise_binary_reference(
  903. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  904. supports_lhs_python_scalar=False,
  905. supports_rhs_python_scalar=False,
  906. )
  907. def atan2(a, b):
  908. return prims.atan2(a, b)
  909. @_make_elementwise_binary_reference(
  910. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  911. )
  912. def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  913. return prims.bitwise_and(a, b)
  914. @_make_elementwise_binary_reference(
  915. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  916. )
  917. def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  918. return prims.shift_left(a, b)
  919. @_make_elementwise_binary_reference(
  920. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  921. )
  922. def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  923. return prims.bitwise_or(a, b)
  924. @_make_elementwise_binary_reference(
  925. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  926. )
  927. def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  928. return prims.shift_right_arithmetic(a, b)
  929. @_make_elementwise_binary_reference(
  930. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  931. )
  932. def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  933. return prims.bitwise_xor(a, b)
  934. @_make_elementwise_binary_reference(
  935. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  936. supports_lhs_python_scalar=False,
  937. )
  938. def copysign(
  939. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  940. ):
  941. if isinstance(b, Number) and isinstance(a, Tensor):
  942. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  943. elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
  944. msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
  945. raise RuntimeError(msg)
  946. return where(signbit(b), neg(abs(a)), abs(a))
  947. # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
  948. @register_decomposition(aten.div)
  949. @out_wrapper()
  950. def div(
  951. a: Union[TensorLikeType, NumberType],
  952. b: Union[TensorLikeType, NumberType],
  953. *,
  954. rounding_mode: Optional[str] = None,
  955. ):
  956. """
  957. Reference implementation of torch.div
  958. """
  959. if rounding_mode is None:
  960. return true_divide(a, b)
  961. elif rounding_mode == "trunc":
  962. return trunc_divide(a, b)
  963. elif rounding_mode == "floor":
  964. return floor_divide(a, b)
  965. else:
  966. msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
  967. raise ValueError(msg)
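# Illustrative usage of the div reference above, assuming standard torch.div
# rounding_mode semantics ("trunc" rounds toward zero, "floor" rounds toward -inf):
# >>> torch.div(torch.tensor(-7.), torch.tensor(2.), rounding_mode="trunc")
# tensor(-3.)
# >>> torch.div(torch.tensor(-7.), torch.tensor(2.), rounding_mode="floor")
# tensor(-4.)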
  968. @_make_elementwise_binary_reference(
  969. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  970. supports_lhs_python_scalar=False,
  971. )
  972. def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  973. return prims.eq(a, b)
  974. @_make_elementwise_binary_reference(
  975. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
  976. )
  977. def pow(
  978. a: Union[TensorLikeType, NumberType],
  979. b: Union[TensorLikeType, NumberType],
  980. ) -> TensorLikeType:
  981. assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
  982. if isinstance(b, Number):
  983. if b == 1.0:
  984. return a.clone() # type: ignore[return-value,union-attr]
  985. elif b == 2.0:
  986. return a * a # type: ignore[return-value]
  987. elif b == 0.5:
  988. return torch.sqrt(a) # type: ignore[arg-type]
  989. elif isinstance(a, Number):
  990. if a == 1.0:
  991. return torch.fill(b, True)
  992. if a == 2.0 and (
  993. utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
  994. ):
  995. return torch.exp2(b)
  996. return prims.pow(a, b)
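# Illustrative usage of the pow reference above; the scalar-exponent fast paths
# (b == 1, 2, or 0.5) avoid calling prims.pow, assuming standard torch.pow semantics:
# >>> torch.pow(torch.tensor([2., 3.]), 2)
# tensor([4., 9.])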
  997. # Float power has its own implementation because it has unique type promotion.
  998. # CompositeImplicitAutograd - don't register decomp
  999. @out_wrapper()
  1000. def float_power(
  1001. a: Union[TensorLikeType, NumberType],
  1002. b: Union[TensorLikeType, NumberType],
  1003. ) -> Tensor:
  1004. if isinstance(a, Number) and isinstance(b, Number):
  1005. raise ValueError(
  1006. "Receive two Number inputs to an elementwise binary operation!"
  1007. )
  1008. # Handles type promotion
  1009. dtype = utils.get_higher_dtype(a, b)
  1010. assert dtype is not None
  1011. if utils.is_complex_dtype(dtype):
  1012. dtype = torch.complex128
  1013. else:
  1014. dtype = torch.float64
  1015. # Float power has the following contiguous cast behavior to be
  1016. # consistent with its C++ impl
  1017. a = _maybe_convert_to_dtype(a, dtype)
  1018. b = _maybe_convert_to_dtype(b, dtype)
  1019. a, b = _maybe_broadcast(a, b)
  1020. return pow(a, b)
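# Illustrative usage of the float_power reference above; inputs are promoted to
# double (or complex128) before exponentiation, assuming standard torch.float_power
# semantics:
# >>> torch.float_power(torch.tensor([2, 3]), 2)
# tensor([4., 9.], dtype=torch.float64)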
  1021. # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
  1022. # tensor(-0.250000000000000, dtype=torch.float64)
  1023. #
  1024. # >>> b = torch.tensor(-0.0010, dtype=torch.float64)
  1025. # tensor(-0.001000000000000, dtype=torch.float64)
  1026. #
1027. # Note: In this case, casting the float to double pads the float's mantissa with zeros,
1028. # while constructing the value directly as a double produces a different mantissa.
  1029. # >>> torch.tensor(-0.001).to(dtype=torch.float64)
  1030. # tensor(-0.001000000047497, dtype=torch.float64)
  1031. #
1032. # Floor Division
1033. # The two results below differ (250 vs. 249) because torch.remainder(a, b) = -0.001.
  1034. #
  1035. # >>> torch.floor(torch.true_divide(a, b))
  1036. # tensor(250., dtype=torch.float64)
  1037. #
  1038. # >>> torch.div(a, b, rounding_mode='floor')
  1039. # tensor(249., dtype=torch.float64)
  1040. #
  1041. # Definition: a // b = (a - remainder(a, b)) / b
  1042. # >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
  1043. # tensor(249., dtype=torch.float64)
  1044. #
  1045. # For reference, see CPython's implementation:
  1046. # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
  1047. @_make_elementwise_binary_reference(
  1048. type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1049. supports_two_python_scalars=True,
  1050. should_register_decomposition=False,
  1051. )
  1052. def floor_divide(
  1053. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  1054. ):
  1055. # Wrap scalars because some references only accept tensor arguments.
  1056. if isinstance(a, Number) and isinstance(b, Number):
  1057. a = scalar_tensor(a)
  1058. b = scalar_tensor(b)
  1059. elif isinstance(b, Number) and isinstance(a, Tensor):
  1060. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  1061. elif isinstance(a, Number) and isinstance(b, Tensor):
  1062. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  1063. elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
  1064. if a.device == torch.device("cpu"):
  1065. msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
  1066. raise RuntimeError(msg)
  1067. else:
  1068. b = prims.device_put(b, device=a.device)
  1069. assert isinstance(a, Tensor) and isinstance(b, Tensor)
  1070. dtype = a.dtype
  1071. if utils.is_float_dtype(dtype):
  1072. return _floor_divide_float(a, b)
  1073. elif utils.is_integer_dtype(dtype):
  1074. return _floor_divide_integer(a, b)
  1075. else:
  1076. torch._check(False, lambda: f"{dtype} not supported for floor_divide")
  1077. def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
  1078. a, b = _maybe_broadcast(a, b)
  1079. if not a.dtype.is_signed:
  1080. return prims.div(a, b)
  1081. # Convert truncation to flooring:
  1082. offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
  1083. return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
  1084. def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
  1085. mod = fmod(a, b)
  1086. div = true_divide(sub(a, mod), b)
  1087. # Ensure that the remainder has the same sign as denominator
  1088. different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
  1089. non_zero_remainder = ne(mod, 0)
  1090. mask = bitwise_and(non_zero_remainder, different_signed_inputs)
  1091. div = where(mask, sub(div, 1), div)
  1092. # Map quotient to nearest integer value
  1093. floor_div = floor(div)
  1094. mask = gt(sub(div, floor_div), 0.5)
  1095. floor_div = where(mask, add(floor_div, 1), floor_div)
  1096. basic_div = true_divide(a, b)
  1097. zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)
  1098. # If quotient is zero, copy signbit from true_divide quotient
  1099. floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))
  1100. # If denominator is zero, then follow true_divide behavior
  1101. return where(ne(b, 0), floor_div, basic_div)
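# Illustrative usage of the floor_divide reference above, assuming standard
# torch.floor_divide semantics (the quotient is rounded toward -inf):
# >>> torch.floor_divide(torch.tensor([-7., 7.]), torch.tensor([2., 2.]))
# tensor([-4.,  3.])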
  1102. @_make_elementwise_binary_reference(
  1103. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1104. supports_lhs_python_scalar=False,
  1105. supports_rhs_python_scalar=False,
  1106. )
  1107. def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1108. return prims.fmax(a, b)
  1109. @_make_elementwise_binary_reference(
  1110. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1111. supports_lhs_python_scalar=False,
  1112. supports_rhs_python_scalar=False,
  1113. )
  1114. def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1115. return prims.fmin(a, b)
  1116. @_make_elementwise_binary_reference(
  1117. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1118. supports_lhs_python_scalar=False,
  1119. supports_rhs_python_scalar=True,
  1120. )
  1121. def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1122. return prims.fmod(a, b)
  1123. @register_decomposition(aten.frexp)
  1124. @out_wrapper("mantissa", "exponent")
  1125. def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
  1126. return torch.return_types.frexp(prims.frexp(self))
  1127. @_make_elementwise_binary_reference(
  1128. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1129. supports_lhs_python_scalar=False,
  1130. supports_rhs_python_scalar=False,
  1131. )
  1132. def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1133. return prims.gcd(a, b)
  1134. @_make_elementwise_binary_reference(
  1135. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1136. supports_lhs_python_scalar=False,
  1137. )
  1138. def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1139. return prims.ge(a, b)
  1140. @_make_elementwise_binary_reference(
  1141. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1142. supports_lhs_python_scalar=False,
  1143. )
  1144. def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1145. return prims.gt(a, b)
  1146. @_make_elementwise_binary_reference(
  1147. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1148. supports_lhs_python_scalar=False,
  1149. supports_rhs_python_scalar=False,
  1150. )
  1151. def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
  1152. input_eq_zero = torch.eq(input, 0)
  1153. input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
  1154. zeros_and_ones = torch.where(input_lt_zero, 0, 1)
  1155. output = torch.where(input_eq_zero, values, zeros_and_ones)
  1156. return output
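# Illustrative usage of the heaviside reference above, assuming standard
# torch.heaviside semantics (`values` is used only where the input is exactly 0):
# >>> torch.heaviside(torch.tensor([-1.5, 0., 2.]), torch.tensor(0.5))
# tensor([0.0000, 0.5000, 1.0000])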
  1157. @_make_elementwise_binary_reference(
  1158. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1159. supports_lhs_python_scalar=False,
  1160. supports_rhs_python_scalar=False,
  1161. )
  1162. def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1163. return prims.hypot(a, b)
  1164. @_make_elementwise_binary_reference(
  1165. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1166. supports_lhs_python_scalar=False,
  1167. supports_rhs_python_scalar=False,
  1168. )
  1169. def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1170. return prims.igamma(a, b)
  1171. @_make_elementwise_binary_reference(
  1172. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1173. supports_lhs_python_scalar=False,
  1174. supports_rhs_python_scalar=False,
  1175. )
  1176. def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1177. return prims.igammac(a, b)
  1178. def _check_close_args(
  1179. name: str,
  1180. a: TensorLikeType,
  1181. b: TensorLikeType,
  1182. rtol: float,
  1183. atol: float,
  1184. ) -> None:
  1185. torch._check_value(
  1186. a.dtype == b.dtype,
  1187. lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
  1188. )
  1189. torch._check(
  1190. rtol >= 0,
  1191. lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
  1192. )
  1193. torch._check(
  1194. atol >= 0,
  1195. lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
  1196. )
  1197. # CompositeImplicitAutograd - don't register decomp
  1198. def isclose(
  1199. a: TensorLikeType,
  1200. b: TensorLikeType,
  1201. rtol: float = 1e-05,
  1202. atol: float = 1e-08,
  1203. equal_nan: bool = False,
  1204. ) -> TensorLikeType:
  1205. _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)
  1206. close = eq(a, b)
  1207. if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
  1208. close = logical_or(close, logical_and(isnan(a), isnan(b)))
  1209. # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
  1210. # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
  1211. if atol == 0 and rtol == 0:
  1212. return close
  1213. # Note [closeness error computation]
  1214. # atol and rtol are provided as doubles, so the computation
  1215. # rtol * other will produce a float or complex tensor.
  1216. # When the difference (self - other) is compared to it then the
  1217. # tensor representing the difference will also be cast to float or complex.
  1218. # However, since (self - other) in uint8 is very likely to produce a
  1219. # negative value, this moves the cast forward so the difference is
  1220. # always computed in a float or complex type.
  1221. # If the values of the integer tensors cannot be exactly represented
  1222. # by the default scalar type then this may cause an incorrect result.
  1223. if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
  1224. a = prims.convert_element_type(a, torch.get_default_dtype())
  1225. b = prims.convert_element_type(b, torch.get_default_dtype())
  1226. allowed_error = add(atol, abs(mul(b, rtol)))
  1227. actual_error = abs(sub(a, b))
  1228. # Computes finite closeness
  1229. result = logical_or(
  1230. close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
  1231. )
  1232. return result
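# Illustrative usage of the isclose reference above, assuming standard
# torch.isclose semantics (equal_nan=True makes NaN compare equal to NaN):
# >>> x = torch.tensor([1.0, float("nan")])
# >>> torch.isclose(x, x)
# tensor([ True, False])
# >>> torch.isclose(x, x, equal_nan=True)
# tensor([True, True])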
  1233. @_make_elementwise_binary_reference(
  1234. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1235. supports_lhs_python_scalar=False,
  1236. supports_rhs_python_scalar=False,
  1237. )
  1238. def lcm(a: TensorLikeType, b: TensorLikeType):
  1239. dtype = a.dtype
  1240. # promoting to int32 to maintain 100% consistency with C++ and to
  1241. # prevent overflow in case of int8 and int16
  1242. promote_to_int = dtype in (torch.int8, torch.int16)
  1243. if promote_to_int:
  1244. a = prims.convert_element_type(a, torch.int32)
  1245. b = prims.convert_element_type(b, torch.int32)
  1246. g = torch.gcd(a, b)
  1247. # Avoid division by zero in case gcd(0, 0) == 0
  1248. g = torch.where(g == 0, 1, g)
  1249. res = torch.abs(prims.div(a, g) * b)
  1250. return res if not promote_to_int else prims.convert_element_type(res, dtype)
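# Illustrative usage of the lcm reference above, assuming standard torch.lcm
# semantics (an lcm involving zero is zero, which is why gcd == 0 is guarded against):
# >>> torch.lcm(torch.tensor([4, 0]), torch.tensor([6, 5]))
# tensor([12,  0])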
  1251. @_make_elementwise_binary_reference(
  1252. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1253. supports_lhs_python_scalar=False,
  1254. )
  1255. def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1256. return prims.le(a, b)
  1257. @_make_elementwise_binary_reference(
  1258. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1259. supports_lhs_python_scalar=False,
  1260. supports_rhs_python_scalar=False,
  1261. )
  1262. def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1263. # Nb. this implementation does not distribute the gradients evenly when a == b
  1264. mask = torch.real(a) >= torch.real(b)
  1265. max_ = torch.where(mask, a, b)
  1266. min_ = torch.where(mask, b, a)
  1267. inf_mask = torch.logical_and(
  1268. torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
  1269. )
  1270. if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype):
1271. # the branches below handle complex-dtype edge cases
  1272. neg_min_mask = torch.real(min_) < 0
  1273. inf_vals = torch.where(
  1274. neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_))
  1275. )
  1276. non_nan_vals = torch.where(
  1277. inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_))
  1278. )
  1279. # the type for full_like does not include tensor yet
  1280. nan_mask = torch.isnan(min_)
  1281. return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload]
  1282. else:
  1283. return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))
  1284. @_make_elementwise_binary_reference(
  1285. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1286. supports_lhs_python_scalar=False,
  1287. supports_rhs_python_scalar=False,
  1288. )
  1289. def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1290. torch._check(
  1291. not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)),
  1292. lambda: "logaddexp2 doesn't support complex dtypes",
  1293. )
  1294. # Nb. this implementation does not distribute the gradients evenly when a == b
  1295. mask = a >= b
  1296. max_ = torch.where(mask, a, b)
  1297. min_ = torch.where(mask, b, a)
  1298. inf_mask = torch.logical_and(torch.isinf(a), a == b)
  1299. inv_log_2 = 1.0 / math.log(2)
  1300. result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2
  1301. return torch.where(inf_mask, a, result)
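# Illustrative usage of the logaddexp/logaddexp2 references above, assuming
# standard torch semantics (log(exp(a) + exp(b)) and log2(2**a + 2**b) respectively):
# >>> torch.logaddexp(torch.tensor([0.]), torch.tensor([0.]))
# tensor([0.6931])
# >>> torch.logaddexp2(torch.tensor([0.]), torch.tensor([0.]))
# tensor([1.])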
  1302. @_make_elementwise_binary_reference(
  1303. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1304. )
  1305. def logical_and(a: TensorLikeType, b: TensorLikeType):
  1306. if not utils.is_boolean_dtype(a.dtype):
  1307. a = a != 0
  1308. if not utils.is_boolean_dtype(b.dtype):
  1309. b = b != 0
  1310. return a & b
  1311. @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
  1312. def logical_not(a: TensorLikeType):
  1313. if not utils.is_boolean_dtype(a.dtype):
  1314. return a == 0
  1315. return ~a
  1316. @_make_elementwise_binary_reference(
  1317. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1318. )
  1319. def logical_or(a: TensorLikeType, b: TensorLikeType):
  1320. if not utils.is_boolean_dtype(a.dtype):
  1321. a = a != 0
  1322. if not utils.is_boolean_dtype(b.dtype):
  1323. b = b != 0
  1324. return bitwise_or(a, b)
  1325. # TODO: skip unnecessary conversion of long to float
  1326. @_make_elementwise_binary_reference(
  1327. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1328. )
  1329. def logical_xor(a: TensorLikeType, b: TensorLikeType):
  1330. if not utils.is_boolean_dtype(a.dtype):
  1331. a = a != 0
  1332. if not utils.is_boolean_dtype(b.dtype):
  1333. b = b != 0
  1334. return a ^ b
  1335. @_make_elementwise_binary_reference(
  1336. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1337. supports_lhs_python_scalar=False,
  1338. )
  1339. def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1340. return prims.lt(a, b)
  1341. @_make_elementwise_binary_reference(
  1342. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1343. )
  1344. def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1345. return prims.maximum(a, b)
  1346. @_make_elementwise_binary_reference(
  1347. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1348. )
  1349. def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1350. return prims.minimum(a, b)
  1351. @_make_elementwise_binary_reference(
  1352. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1353. supports_two_python_scalars=True,
  1354. )
  1355. def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1356. return prims.mul(a, b)
  1357. @_make_elementwise_binary_reference(
  1358. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
  1359. supports_lhs_python_scalar=False,
  1360. )
  1361. def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1362. return prims.ne(a, b)
  1363. @_make_elementwise_binary_reference(
  1364. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
  1365. supports_lhs_python_scalar=False,
  1366. supports_rhs_python_scalar=False,
  1367. )
  1368. def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1369. return prims.nextafter(a, b)
  1370. @_make_elementwise_binary_reference(
  1371. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1372. )
  1373. def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1374. return prims.remainder(a, b)
  1375. # reverse sub
  1376. @register_decomposition(aten.rsub)
  1377. @out_wrapper()
  1378. def rsub(
  1379. a: Union[TensorLikeType, NumberType],
  1380. b: Union[TensorLikeType, NumberType],
  1381. alpha: NumberType = 1,
  1382. ):
  1383. if isinstance(a, Number):
  1384. msg = "Received a Number for the first argument, but expected a Tensor"
  1385. raise ValueError(msg)
  1386. return torch.sub(b, a, alpha=alpha)
  1387. # TODO: consider refactoring this with add impl
  1388. # sub has its own implementation because it has an alpha argument
  1389. @register_decomposition(aten.sub)
  1390. @out_wrapper()
  1391. @elementwise_type_promotion_wrapper(
  1392. type_promoting_args=("a", "b"),
  1393. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1394. )
  1395. def sub(
  1396. a: Union[TensorLikeType, NumberType],
  1397. b: Union[TensorLikeType, NumberType],
  1398. *,
  1399. alpha: NumberType = 1,
  1400. ):
  1401. """
  1402. Reference implementation of torch.sub
  1403. """
  1404. a, b = _maybe_broadcast(a, b)
  1405. if alpha != 1:
  1406. dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
  1407. python_type = utils.dtype_to_type(dtype)
  1408. if not utils.is_weakly_lesser_type(type(alpha), python_type):
  1409. msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
  1410. raise ValueError(msg)
  1411. if isinstance(b, torch.Tensor):
  1412. b = prims.mul(b, alpha)
  1413. else:
1414. # Take care not to use prims.mul if b is a scalar / symint:
1415. # prims.mul always returns a tensor,
1416. # which would mess with type promotion.
  1417. b = b * alpha
  1418. output = prims.sub(a, b)
  1419. return handle_noncontiguous_outputs([a, b], output)
  1420. @_make_elementwise_binary_reference(
  1421. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1422. name="true_divide",
  1423. aten_op=None, # CompositeImplicitAutograd
  1424. supports_two_python_scalars=True,
  1425. )
  1426. def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
  1427. return prims.div(a, b)
  1428. @register_decomposition(aten.xlogy)
  1429. @out_wrapper()
  1430. @elementwise_type_promotion_wrapper(
  1431. type_promoting_args=("a", "b"),
  1432. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1433. )
  1434. def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
  1435. torch._check(
  1436. isinstance(a, TensorLike) or isinstance(b, TensorLike),
1437. lambda: "Expected either argument a or b to be a Tensor",
  1438. )
  1439. # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
  1440. if isinstance(b, TensorLike) and isinstance(a, Number):
  1441. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  1442. elif isinstance(a, TensorLike) and isinstance(b, Number):
  1443. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  1444. # mypy: expected "Tensor"
  1445. assert isinstance(a, TensorLike)
  1446. assert isinstance(b, TensorLike)
  1447. rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
  1448. return torch.where(torch.isnan(b), float("nan"), rhs)
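# Illustrative usage of the xlogy reference above, assuming standard torch.xlogy
# semantics (the result is 0 where a == 0, even if b == 0):
# >>> torch.xlogy(torch.tensor([0., 2.]), torch.tensor([0., 3.]))
# tensor([0.0000, 2.1972])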
  1449. @_make_elementwise_binary_reference(
  1450. type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1451. aten_op=None, # CompositeImplicitAutograd
  1452. supports_two_python_scalars=True,
  1453. )
  1454. def trunc_divide(
  1455. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  1456. ):
  1457. dtype = utils.get_dtype(a)
  1458. if utils.is_integer_dtype(dtype):
  1459. return prims.div(a, b)
  1460. return trunc(prims.div(a, b))
  1461. #
  1462. # Elementwise Ternary References
  1463. #
  1464. @register_decomposition(aten.addcdiv)
  1465. @out_wrapper()
  1466. @elementwise_type_promotion_wrapper(
  1467. type_promoting_args=("self", "tensor1", "tensor2"),
  1468. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1469. )
  1470. def addcdiv(
  1471. self: TensorLikeType,
  1472. tensor1: TensorLikeType,
  1473. tensor2: TensorLikeType,
  1474. *,
  1475. value: NumberType = 1,
  1476. ) -> TensorLikeType:
  1477. """
  1478. Reference implementation of torch.addcdiv
  1479. """
  1480. if value is not None:
  1481. dtype = self.dtype # no scalars allowed, see add
  1482. python_type = utils.dtype_to_type(dtype)
  1483. torch._check_value(
  1484. utils.is_weakly_lesser_type(type(value), python_type),
  1485. lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
  1486. )
  1487. return self + value * tensor1 / tensor2
  1488. @register_decomposition(aten.addcmul)
  1489. @out_wrapper()
  1490. @elementwise_type_promotion_wrapper(
  1491. type_promoting_args=("self", "tensor1", "tensor2"),
  1492. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1493. )
  1494. def addcmul(
  1495. self: TensorLikeType,
  1496. tensor1: TensorLikeType,
  1497. tensor2: TensorLikeType,
  1498. *,
  1499. value: NumberType = 1,
  1500. ) -> TensorLikeType:
  1501. """
  1502. Reference implementation of torch.addcmul
  1503. """
  1504. if value is not None:
  1505. dtype = self.dtype # no scalars allowed, see add
  1506. python_type = utils.dtype_to_type(dtype)
  1507. torch._check_value(
  1508. utils.is_weakly_lesser_type(type(value), python_type),
  1509. lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
  1510. )
  1511. return self + value * tensor1 * tensor2
  1512. @register_decomposition(aten.clamp)
  1513. @out_wrapper()
  1514. @elementwise_type_promotion_wrapper(
  1515. type_promoting_args=("a", "min", "max"),
  1516. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  1517. )
  1518. def clamp(
  1519. a: TensorLikeType,
  1520. min: Optional[TensorOrNumberLikeType] = None,
  1521. max: Optional[TensorOrNumberLikeType] = None,
  1522. ) -> TensorLikeType:
1523. # NOTE: the gradient behavior of this `where`-based implementation is not consistent for `nan`
  1524. if min is None and max is None:
  1525. msg = "clamp called but both min and max are none!"
  1526. raise ValueError(msg)
  1527. if min is not None:
  1528. a_isnan = torch.isnan(a)
  1529. condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type]
1530. # We should also propagate `nan` coming from the boundaries. However, that's
1531. # not necessary, since `ge` already returns `False` when either operand is
1532. # `nan`. So the line below would be redundant:
1533. # `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
  1534. a = torch.where(condition, a, min) # type: ignore[arg-type]
  1535. if max is not None:
  1536. a_isnan = torch.isnan(a)
  1537. # same as above, no need to adjust `nan` from `max`
  1538. condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type]
  1539. a = torch.where(condition, a, max) # type: ignore[arg-type]
  1540. return a
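# Illustrative usage of the clamp reference above, assuming standard torch.clamp
# semantics (NaN inputs propagate through the isnan-aware `where` conditions):
# >>> torch.clamp(torch.tensor([-2., 0.5, 3.]), min=0., max=1.)
# tensor([0.0000, 0.5000, 1.0000])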
  1541. @register_decomposition(aten.clamp_min)
  1542. @out_wrapper()
  1543. def clamp_min(
  1544. self: TensorLikeType,
  1545. min: Optional[TensorOrNumberLikeType] = None,
  1546. ) -> TensorLikeType:
  1547. return torch.clamp(self, min=min) # type: ignore[arg-type]
  1548. @register_decomposition(aten.clamp_max)
  1549. @out_wrapper()
  1550. def clamp_max(
  1551. self: TensorLikeType,
  1552. max: Optional[TensorOrNumberLikeType] = None,
  1553. ) -> TensorLikeType:
  1554. return torch.clamp(self, max=max) # type: ignore[arg-type]
  1555. #
  1556. # Conditional references
  1557. #
  1558. # https://pytorch.org/docs/stable/generated/torch.where.html
  1559. # TODO: implement alternate where
  1560. @register_decomposition(aten.where)
  1561. @out_wrapper()
  1562. @elementwise_type_promotion_wrapper(
  1563. type_promoting_args=("a", "b"),
  1564. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
  1565. )
  1566. def where(
  1567. pred: Tensor,
  1568. a: Optional[TensorOrNumberLikeType] = None,
  1569. b: Optional[TensorOrNumberLikeType] = None,
  1570. ):
  1571. """ """
  1572. if a is None or b is None:
  1573. raise NotImplementedError
  1574. utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
  1575. torch._check(
  1576. pred.dtype is torch.bool,
  1577. lambda: f"expected predicate to be bool, got {pred.dtype}",
  1578. )
  1579. pred, a, b = _maybe_broadcast(pred, a, b)
  1580. return prims.where(pred, a, b)
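# Illustrative usage of the where reference above, assuming standard torch.where
# semantics (elements of a are selected where pred is True, b elsewhere):
# >>> torch.where(torch.tensor([True, False]), torch.tensor([1., 2.]), torch.tensor([10., 20.]))
# tensor([ 1., 20.])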
  1581. #
  1582. # Data Movement References
  1583. #
  1584. @register_decomposition(aten.clone)
  1585. @out_wrapper()
  1586. def clone(
  1587. a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
  1588. ) -> TensorLikeType:
  1589. result = prims.clone(a, memory_format=memory_format)
  1590. return result
  1591. def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
  1592. if not allow_cross_device and a.device != b.device:
  1593. msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!"
  1594. raise RuntimeError(msg)
  1595. return prims.copy_to(a, b)
  1596. @register_decomposition(aten.item)
  1597. def item(a: TensorLikeType) -> NumberType:
  1598. if a.numel() != 1:
  1599. msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
  1600. raise ValueError(msg)
  1601. # NOTE: explicit conversion is necessary for bool!
  1602. # See https://github.com/pytorch/pytorch/issues/78071
  1603. number_type = utils.dtype_to_type(a.dtype)
  1604. return number_type(prims.item(a))
  1605. # fast path when `to` returns an alias to input. This mimics the same function in aten
  1606. def _to_will_alias(
  1607. a: TensorLikeType,
  1608. device: Optional[DeviceLikeType] = None,
  1609. dtype: Optional[torch.dtype] = None,
  1610. copy: Optional[bool] = None,
  1611. layout: Optional[torch.layout] = None,
  1612. memory_format: Optional[torch.memory_format] = None,
  1613. pin_memory: Optional[bool] = False,
  1614. non_blocking: bool = False, # not using non_blocking
  1615. ) -> bool:
  1616. return (
  1617. not copy
  1618. and (device is None or a.device == device)
  1619. and (dtype is None or a.dtype == dtype)
  1620. and (layout is None or a.layout == layout)
  1621. # is_pinned issue #84925
  1622. # and (pin_memory is None or pin_memory == a.is_pinned())
  1623. and (
  1624. memory_format is None
  1625. or memory_format == torch.preserve_format
  1626. or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
  1627. )
  1628. )
  1629. @singledispatch
  1630. def _to_dispatch(*args, **kwargs):
  1631. raise NotImplementedError
  1632. @_to_dispatch.register
  1633. def _to_device(
  1634. device: torch.device,
  1635. dtype: torch.dtype,
  1636. non_blocking: bool = False,
  1637. copy: bool = False,
  1638. memory_format: Optional[torch.memory_format] = None,
  1639. ) -> Dict[str, Any]:
  1640. kwargs = {
  1641. "device": device,
  1642. "dtype": dtype,
  1643. "non_blocking": non_blocking,
  1644. "copy": copy,
  1645. "memory_format": memory_format,
  1646. }
  1647. return kwargs
  1648. @_to_dispatch.register
  1649. def _to_device_str(
  1650. device: str,
  1651. dtype: torch.dtype,
  1652. non_blocking: bool = False,
  1653. copy: bool = False,
  1654. memory_format: Optional[torch.memory_format] = None,
  1655. ) -> Dict[str, Any]:
  1656. kwargs = {
  1657. "device": torch.device(device),
  1658. "dtype": dtype,
  1659. "non_blocking": non_blocking,
  1660. "copy": copy,
  1661. "memory_format": memory_format,
  1662. }
  1663. return kwargs
  1664. @_to_dispatch.register
  1665. def _to_dtype(
  1666. dtype: torch.dtype,
  1667. non_blocking: bool = False,
  1668. copy: bool = False,
  1669. memory_format: Optional[torch.memory_format] = None,
  1670. ) -> Dict[str, Any]:
  1671. kwargs = {
  1672. "dtype": dtype,
  1673. "non_blocking": non_blocking,
  1674. "copy": copy,
  1675. "memory_format": memory_format,
  1676. }
  1677. return kwargs
  1678. @_to_dispatch.register
  1679. def _to_other(
  1680. other: Tensor,
  1681. non_blocking: bool = False,
  1682. copy: bool = False,
  1683. memory_format: Optional[torch.memory_format] = None,
  1684. ) -> Dict[str, Any]:
  1685. device = other.device
  1686. dtype = other.dtype
  1687. layout = other.layout
  1688. # is_pinned issue #84925
  1689. # pin_memory = other.is_pinned()
  1690. kwargs = {
  1691. "device": device,
  1692. "dtype": dtype,
  1693. "layout": layout,
  1694. "non_blocking": non_blocking,
  1695. "copy": copy,
  1696. "memory_format": memory_format,
  1697. }
  1698. return kwargs
  1699. # remove to_kwargs that is already present in `a`
  1700. def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
  1701. options_to_check = ["dtype", "device", "layout", "memory_format"]
  1702. # "device" option could be passed a str instead torch.device
  1703. if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
  1704. to_kwargs["device"] = torch.device(to_kwargs["device"])
  1705. for kw in options_to_check:
  1706. if kw in to_kwargs:
  1707. if (
  1708. (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
  1709. or (
  1710. kw == "device"
  1711. and to_kwargs[kw].type == a.device.type
  1712. and (
  1713. not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
  1714. )
  1715. )
  1716. or (
  1717. getattr(a, kw, None) == to_kwargs[kw]
  1718. ) # this also handles {"memory_format": None}
  1719. ):
  1720. to_kwargs.pop(kw)
  1721. def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
1722. # handle dispatch via positional arguments
  1723. if len(args) != 0:
  1724. kwargs = _to_dispatch(*args, **kwargs)
  1725. # TODO: is_pinned is not currently supported in refs or fake_tensor
  1726. # https://github.com/pytorch/pytorch/issues/84925
  1727. assert "pin_memory" not in kwargs
  1728. _canonicalize_to_arguments(a, kwargs)
  1729. if _to_will_alias(a, **kwargs):
  1730. return a
  1731. copy = kwargs.pop("copy") if "copy" in kwargs else False
  1732. non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False
  1733. # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
  1734. if (
  1735. (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
  1736. and (not non_blocking)
  1737. and ("memory_format" not in kwargs)
  1738. and ("device" not in kwargs)
  1739. and ("layout" not in kwargs)
  1740. # is_pinned issue #84925
  1741. # and ("pin_memory" not in kwargs)
  1742. ):
  1743. return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))
  1744. result = torch.empty_like(a, **kwargs)
  1745. # TODO: non_blocking should be handled by `copy_to`
  1746. copy_to(result, a)
  1747. return result
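# Illustrative usage of the to reference above; a pure dtype change is expected to
# short-circuit to prims.convert_element_type, and a no-op `to` may alias the input
# (assuming standard Tensor.to semantics):
# >>> t = torch.ones(2, dtype=torch.float32)
# >>> t.to(torch.float64).dtype
# torch.float64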
  1748. #
  1749. # Reduction references
  1750. #
  1751. def _reduction(
  1752. a: TensorLikeType,
  1753. prim: Callable,
  1754. *,
  1755. has_identity: bool = True,
  1756. accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only
  1757. dims: Optional[DimsType] = None,
  1758. keepdims: bool = False,
  1759. dtype: Optional[torch.dtype] = None, # should be specified for ops that support it
  1760. out: Optional[Tensor] = None,
  1761. output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
  1762. ) -> TensorLikeType: # it is usually SAME, but I want
  1763. # ref writers to actually think about what to put here
  1764. assert isinstance(a, TensorLike)
  1765. if a.ndim > 64:
  1766. raise RuntimeError(
  1767. f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!"
  1768. )
  1769. if out is not None:
  1770. assert isinstance(out, TensorLike)
  1771. if dtype is not None:
  1772. # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
  1773. if dtype != out.dtype:
  1774. raise RuntimeError(
  1775. "dtype argument and out dtype must match in reduction"
  1776. )
  1777. if not accepts_dim_tuple:
  1778. assert dims is None or isinstance(dims, Dim)
  1779. if isinstance(dims, Dim):
  1780. dims = (dims,) # type: ignore[assignment]
  1781. dims = utils.reduction_dims(a.shape, dims)
  1782. if not has_identity:
  1783. valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
  1784. if not valid_shape:
  1785. raise RuntimeError(
  1786. "reducing over zero-size dimension for reduction operation without identity"
  1787. )
  1788. computation_dtype, result_dtype = utils.reduction_dtypes(
  1789. a, output_dtype_kind, dtype
  1790. )
  1791. a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign]
  1792. result = prim(a, dims)
  1793. if keepdims:
  1794. output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
  1795. broadcast_dims = [i for i in range(a.ndim) if i not in dims]
  1796. result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
  1797. if out is not None:
  1798. assert result_dtype is not None
  1799. if dtype is not None and result_dtype != out.dtype:
  1800. raise RuntimeError(
  1801. "Expected the dtype of reduction result and out to match"
  1802. )
  1803. out = _maybe_resize_out(out, result.shape)
  1804. return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
  1805. if result.dtype != result_dtype and result_dtype is not None:
  1806. result = prims.convert_element_type(result, result_dtype)
  1807. return result
  1808. def _make_copy_from_view(fn):
  1809. """
1810. Given a view function (e.g. torch.diagonal), generates its copy variant (e.g. torch.diagonal_copy)
  1811. """
  1812. name = fn.__name__
  1813. fn = out_wrapper()(fn)
  1814. def _fn(*args, out=None, **kwargs):
  1815. result = fn(*args, out=out, **kwargs)
  1816. if out is None:
  1817. return result.clone(memory_format=torch.contiguous_format)
  1818. return result
  1819. copy_name = f"{name}_copy"
  1820. _fn.__name__ = copy_name
  1821. _fn = register_decomposition(getattr(aten, copy_name))(_fn)
  1822. return _fn
  1823. # Saves Python all
  1824. py_all = all
  1825. @register_decomposition(aten.all)
  1826. @out_wrapper()
  1827. def all(
  1828. a: TensorLikeType,
  1829. dim: Optional[DimsType] = None,
  1830. keepdim: bool = False,
  1831. ) -> TensorLikeType:
  1832. result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
  1833. if a.dtype == torch.uint8:
  1834. result = result.to(dtype=torch.uint8)
  1835. return result
  1836. # Saves Python any
  1837. py_any = any
  1838. @register_decomposition(aten.any)
  1839. @out_wrapper()
  1840. def any(
  1841. a: TensorLikeType,
  1842. dim: Optional[DimsType] = None,
  1843. keepdim: bool = False,
  1844. ) -> TensorLikeType:
  1845. a_ = _maybe_convert_to_dtype(a, torch.bool)
  1846. if isinstance(dim, (list, tuple)) and len(dim) == 0:
  1847. result = a_.clone()
  1848. else:
  1849. result = a_.sum(dim=dim, keepdim=keepdim).ne(False)
  1850. # Preserves uint8 -- probably a legacy mask thing
  1851. if a.dtype is torch.uint8:
  1852. return prims.convert_element_type(result, torch.uint8)
  1853. return result
  1854. @register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
  1855. def sum(
  1856. a: TensorLikeType,
  1857. dim: Union[Optional[int], Optional[List[int]]] = None,
  1858. keepdim: bool = False,
  1859. *,
  1860. dtype: Optional[torch.dtype] = None,
  1861. out: Optional[Tensor] = None,
  1862. ) -> TensorLikeType:
  1863. if dtype is None:
  1864. if out is not None:
  1865. dtype = out.dtype
  1866. elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
  1867. dtype = torch.int64
  1868. else:
  1869. dtype = a.dtype
  1870. # reduces over all dimensions if dim=() is passed
  1871. if dim == () or dim == []:
  1872. dim = None
  1873. return _reduction(
  1874. a,
  1875. prims.sum,
  1876. dims=dim,
  1877. keepdims=keepdim,
  1878. dtype=dtype,
  1879. out=out,
  1880. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
  1881. )
  1882. def sum_to_size(
  1883. a: Tensor,
  1884. *shape,
  1885. ) -> Tensor:
  1886. shape = utils.extract_shape_from_varargs(shape, validate=False)
  1887. torch._check(
  1888. utils.is_expandable_to(shape, a.shape),
  1889. lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
  1890. )
  1891. # In ATen scalar tensors are sent through sum and the result is returned as
  1892. # type promoted
  1893. if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
  1894. return prims.view_of(a)
  1895. leading_dims = a.ndim - len(shape)
  1896. reduce_dims = tuple(range(leading_dims)) + tuple(
  1897. i
  1898. for i in range(leading_dims, len(shape))
  1899. if shape[i - leading_dims] == 1 and a.shape[i] != 1
  1900. )
  1901. return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
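# Illustrative usage of the sum_to_size reference above, assuming standard
# Tensor.sum_to_size semantics (dimensions of size 1 in the target shape are reduced):
# >>> torch.ones(2, 3).sum_to_size(1, 3)
# tensor([[2., 2., 2.]])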
  1902. @register_decomposition(aten.prod)
  1903. def prod(
  1904. a: TensorLikeType,
  1905. dim: Union[Optional[int], Optional[List[int]]] = None,
  1906. keepdim: bool = False,
  1907. *,
  1908. dtype=None,
  1909. out: Optional[Tensor] = None,
  1910. ) -> TensorLikeType:
  1911. if dtype is None:
  1912. if out is not None:
  1913. dtype = out.dtype
  1914. elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
  1915. dtype = torch.int64
  1916. else:
  1917. dtype = a.dtype
  1918. # reduces over all dimensions if dim=() is passed
  1919. if dim == () or dim == []:
  1920. dim = None
  1921. return _reduction(
  1922. a,
  1923. prims.prod,
  1924. dims=dim,
  1925. keepdims=keepdim,
  1926. dtype=dtype,
  1927. out=out,
  1928. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
  1929. )
  1930. @register_decomposition(aten.amin)
  1931. def amin(
  1932. a: TensorLikeType,
  1933. dim: Optional[DimsType] = None,
  1934. keepdim: bool = False,
  1935. *,
  1936. out: Optional[Tensor] = None,
  1937. ) -> TensorLikeType:
  1938. # reduces over all dimensions if dim=() is passed
  1939. if dim == () or dim == []:
  1940. dim = None
  1941. return _reduction(
  1942. a,
  1943. prims.amin,
  1944. dims=dim,
  1945. keepdims=keepdim,
  1946. dtype=None,
  1947. out=out,
  1948. has_identity=False,
  1949. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
  1950. )
  1951. @register_decomposition(aten.amax)
  1952. def amax(
  1953. a: TensorLikeType,
  1954. dim: Optional[DimsType] = None,
  1955. keepdim: bool = False,
  1956. *,
  1957. out: Optional[Tensor] = None,
  1958. ) -> TensorLikeType:
  1959. # reduces over all dimensions if dim=() is passed
  1960. if dim == () or dim == []:
  1961. dim = None
  1962. return _reduction(
  1963. a,
  1964. prims.amax,
  1965. dims=dim,
  1966. keepdims=keepdim,
  1967. dtype=None,
  1968. out=out,
  1969. has_identity=False,
  1970. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
  1971. )
  1972. def _dim_var_dispatch(dim=None, unbiased=None):
  1973. # There's the following overload of torch.var:
  1974. # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
  1975. # We need to explicitly convert bool dims to unbiased arg
  1976. if unbiased is None and isinstance(dim, bool):
  1977. unbiased = dim
  1978. dim = None
  1979. return dim, unbiased
  1980. @register_decomposition(aten.var)
  1981. @out_wrapper()
  1982. def var(
  1983. a: TensorLikeType,
  1984. dim: Optional[DimsType] = None,
  1985. unbiased: Optional[bool] = None,
  1986. keepdim: bool = False,
  1987. *,
  1988. correction: Optional[NumberType] = None,
  1989. ) -> TensorLikeType:
  1990. dim, unbiased = _dim_var_dispatch(dim, unbiased)
  1991. correction = utils.set_correction(unbiased, correction)
  1992. # reduces over all dimensions if dim=() is passed
  1993. if dim == () or dim == []:
  1994. dim = None
  1995. result = _reduction(
  1996. a,
  1997. partial(prims.var, correction=correction),
  1998. dims=dim,
  1999. keepdims=keepdim,
  2000. dtype=None,
  2001. out=None,
  2002. has_identity=True,
  2003. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
  2004. )
  2005. return result
  2006. @register_decomposition(aten.std)
  2007. @out_wrapper()
  2008. def std(
  2009. a: TensorLikeType,
  2010. dim: Union[Optional[int], Optional[List[int]]] = None,
  2011. unbiased: Optional[bool] = None,
  2012. keepdim: bool = False,
  2013. *,
  2014. correction: Optional[NumberType] = None,
  2015. ) -> TensorLikeType:
  2016. dim, unbiased = _dim_var_dispatch(dim, unbiased)
  2017. correction = utils.set_correction(unbiased, correction)
  2018. opmath_dtype, dtype = utils.reduction_dtypes(
  2019. a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
  2020. )
  2021. a = _maybe_convert_to_dtype(a, opmath_dtype)
  2022. a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
  2023. a_std = torch.sqrt(a_var)
  2024. assert dtype is not None
  2025. return _maybe_convert_to_dtype(a_std, dtype)
  2026. @register_decomposition(aten.mean)
  2027. def mean(
  2028. a: TensorLikeType,
  2029. dim: Optional[DimsType] = None,
  2030. keepdim: bool = False,
  2031. *,
  2032. dtype=None,
  2033. out=None,
  2034. ) -> TensorLikeType:
  2035. # reduces over all dimensions if dim=() is passed
  2036. if dim == () or dim == []:
  2037. dim = None
  2038. orig_dtype = dtype
  2039. if dtype is None:
  2040. dtype = a.dtype
  2041. # can't use out wrapper because of this argument
  2042. torch._check(
  2043. out is None or out.dtype == dtype,
  2044. lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
  2045. )
  2046. result = _reduction(
  2047. a,
  2048. prims.sum,
  2049. dims=dim,
  2050. keepdims=keepdim,
  2051. dtype=dtype,
  2052. out=None,
  2053. output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
  2054. )
  2055. torch._check(
  2056. utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
  2057. lambda: (
  2058. f"mean(): could not infer output dtype. "
  2059. f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
  2060. f"a floating point or complex dtype. Got: {dtype}"
  2061. ),
  2062. )
  2063. if isinstance(dim, Dim):
  2064. dim = (dim,) # type: ignore[assignment]
  2065. dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type]
  2066. nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
  2067. result = true_divide(result, nelem)
  2068. result_dtype = a.dtype if dtype is None else dtype
  2069. result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign]
  2070. if out is not None:
  2071. assert isinstance(out, TensorLike)
  2072. out = _maybe_resize_out(out, result.shape)
  2073. return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
  2074. return result
  2075. @register_decomposition(aten.std_mean)
  2076. @out_wrapper("out0", "out1")
  2077. def std_mean(
  2078. a: TensorLikeType,
  2079. dim: Optional[DimsType] = None,
  2080. *,
  2081. unbiased: Optional[bool] = None,
  2082. keepdim: bool = False,
  2083. correction: Optional[NumberType] = None,
  2084. ):
  2085. dim, unbiased = _dim_var_dispatch(dim, unbiased)
  2086. correction = utils.set_correction(unbiased, correction)
  2087. opmath_dtype, dtype = utils.reduction_dtypes(
  2088. a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
  2089. )
  2090. original_dtype = a.dtype
  2091. a = _maybe_convert_to_dtype(a, opmath_dtype)
  2092. a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
  2093. a_std = torch.sqrt(a_var)
  2094. assert dtype is not None
  2095. return (
  2096. _maybe_convert_to_dtype(a_std, dtype),
  2097. _maybe_convert_to_dtype(a_mean, original_dtype),
  2098. )
  2099. @register_decomposition(aten.var_mean)
  2100. @out_wrapper("out0", "out1")
  2101. def var_mean(
  2102. a: TensorLikeType,
  2103. dim: Optional[DimsType] = None,
  2104. unbiased: Optional[bool] = None,
  2105. keepdim: bool = False,
  2106. *,
  2107. correction: Optional[NumberType] = None,
  2108. ):
  2109. dim, unbiased = _dim_var_dispatch(dim, unbiased)
  2110. v = var(a, dim, unbiased, keepdim, correction=correction)
  2111. m = mean(a, dim, keepdim)
  2112. return v, m
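# Illustrative usage of the var_mean reference above, assuming standard
# torch.var_mean semantics (default correction=1, i.e. the unbiased estimator):
# >>> torch.var_mean(torch.tensor([1., 2., 3., 4.]))
# (tensor(1.6667), tensor(2.5000))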
  2113. @register_decomposition(aten.addr)
  2114. @out_wrapper()
  2115. @elementwise_type_promotion_wrapper(
  2116. type_promoting_args=("self", "vec1", "vec2"),
  2117. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  2118. )
  2119. def addr(
  2120. self: TensorLikeType,
  2121. vec1: TensorLikeType,
  2122. vec2: TensorLikeType,
  2123. *,
  2124. beta: NumberType = 1,
  2125. alpha: NumberType = 1,
  2126. ) -> TensorLikeType:
  2127. torch._check(
  2128. vec1.ndim == 1,
  2129. lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
  2130. )
  2131. torch._check(
  2132. vec2.ndim == 1,
  2133. lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
  2134. )
  2135. self = self.expand(vec1.shape[0], vec2.shape[0])
  2136. if utils.is_boolean_dtype(self.dtype):
  2137. # Integers are accepted for booleans
  2138. torch._check(
  2139. is_weakly_lesser_type(type(beta), int),
  2140. lambda: f"expected bool/int beta but got {type(beta)}",
  2141. )
  2142. torch._check(
  2143. is_weakly_lesser_type(type(alpha), int),
  2144. lambda: f"expected bool/int alpha but got {type(beta)}",
  2145. )
  2146. if not beta:
  2147. return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
  2148. else:
  2149. return torch.logical_or(
  2150. self,
  2151. torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
  2152. )
  2153. else:
  2154. torch._check(
  2155. is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
  2156. lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
  2157. )
  2158. torch._check(
  2159. is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
  2160. lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
  2161. )
  2162. if beta == 0:
  2163. # This means NaNs from self are dropped if beta is zero
  2164. return alpha * torch.outer(vec1, vec2)
  2165. else:
  2166. return beta * self + alpha * torch.outer(vec1, vec2)
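# Illustrative usage of the addr reference above, assuming standard torch.addr
# semantics (the result is beta * self + alpha * outer(vec1, vec2)):
# >>> torch.addr(torch.zeros(2, 2), torch.tensor([1., 2.]), torch.tensor([3., 4.]))
# tensor([[3., 4.],
#         [6., 8.]])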
  2167. # CompositeImplicitAutograd - don't register decomp
  2168. def atleast_1d(
  2169. arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
  2170. ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
  2171. """Reference implementation of :func:`torch.atleast_1d`."""
  2172. if not args and isinstance(arg, collections.abc.Sequence):
  2173. args_ = arg
  2174. else:
  2175. assert not isinstance(arg, collections.abc.Sequence)
  2176. args_ = (arg,) + args
  2177. res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
  2178. return res if len(res) > 1 else res[0]
  2179. # Helper function with assert to avoid MyPy error
  2180. # of incompatible type passed to unsqueeze
  2181. def _unsqueeze_atleast(
  2182. at_least_fn: Callable, dim: int, arg: TensorLikeType
  2183. ) -> TensorLikeType:
  2184. arg_ = at_least_fn(arg)
  2185. assert isinstance(arg_, TensorLike)
  2186. return unsqueeze(arg_, dim)
  2187. # CompositeImplicitAutograd - don't register decomp
  2188. def atleast_2d(
  2189. arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
  2190. ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
  2191. """Reference implementation of :func:`torch.atleast_2d`."""
  2192. if not args and isinstance(arg, collections.abc.Sequence):
  2193. args_ = arg
  2194. else:
  2195. assert not isinstance(arg, collections.abc.Sequence)
  2196. args_ = (arg,) + args
  2197. unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
  2198. res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
  2199. return res if len(res) > 1 else res[0]
  2200. # CompositeImplicitAutograd - don't register decomp
  2201. def atleast_3d(
  2202. arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
  2203. ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
  2204. """Reference implementation of :func:`torch.atleast_3d`."""
  2205. if not args and isinstance(arg, collections.abc.Sequence):
  2206. args_ = arg
  2207. else:
  2208. assert not isinstance(arg, collections.abc.Sequence)
  2209. args_ = (arg,) + args
  2210. unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
  2211. res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
  2212. return res if len(res) > 1 else res[0]
  2213. def as_strided(
  2214. a: TensorLikeType,
  2215. size: ShapeType,
  2216. stride: StrideType,
  2217. storage_offset: Optional[int] = None,
  2218. ) -> TensorLikeType:
  2219. storage_offset_int = (
  2220. storage_offset if storage_offset is not None else a.storage_offset()
  2221. )
  2222. return prims.as_strided(a, size, stride, storage_offset_int)
  2223. @register_decomposition(aten.as_strided_scatter)
  2224. @out_wrapper()
  2225. def as_strided_scatter(
  2226. input: TensorLikeType,
  2227. src: TensorLikeType,
  2228. size: ShapeType,
  2229. stride: StrideType,
  2230. storage_offset: Optional[int] = None,
  2231. ) -> TensorLikeType:
  2232. storage_offset_int = 0 if storage_offset is None else storage_offset
  2233. return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
  2234. def broadcast_shapes(*shapes) -> ShapeType:
  2235. return torch.Size(_broadcast_shapes(*shapes))
  2236. @aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2237. @aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
  2238. def broadcast_tensors(*tensors) -> List[TensorLikeType]:
  2239. if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
  2240. tensors = tensors[0]
  2241. return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
  2242. # CompositeImplicitAutograd - don't register decomp
  2243. def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
  2244. start = len(size) - len(a.shape)
  2245. dims = tuple(range(start, len(a.shape) + start))
  2246. return prims.broadcast_in_dim(a, size, dims)
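# Illustrative usage of the broadcast_shapes/broadcast_to references above,
# assuming standard torch broadcasting semantics:
# >>> torch.broadcast_shapes((3, 1), (1, 4))
# torch.Size([3, 4])
# >>> torch.broadcast_to(torch.tensor([1, 2, 3]), (2, 3)).shape
# torch.Size([2, 3])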
  2247. @register_decomposition(aten.cat)
  2248. @out_wrapper()
  2249. @elementwise_type_promotion_wrapper(
  2250. type_promoting_args=("tensors",),
  2251. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
  2252. )
  2253. def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
  2254. def cat_compute_output_memory_format(inputs):
  2255. format = None
  2256. for t in inputs:
  2257. f = utils.suggest_memory_format(t)
  2258. if f == torch.contiguous_format:
  2259. return f
  2260. if format is not None and format != f:
  2261. return torch.contiguous_format
  2262. format = f
  2263. assert format is not None
  2264. return format
  2265. if len(tensors) == 0:
  2266. msg = "cat expects at least one tensor, but received zero!"
  2267. raise ValueError(msg)
  2268. for tensor in tensors:
  2269. assert isinstance(tensor, TensorLike)
  2270. utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
  2271. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  2272. # This is a bit tricky. Naively, you would expect to just pick one
  2273. # arbitrary tensor and check that all tensors match this tensor. However,
  2274. # there is legacy behavior which says that if you have a 1-D empty tensor
  2275. # (0,), this is permissible. So you can't assume that all the tensors
  2276. # have same dimensionality, and you can't assume that the first tensor is
  2277. # the correct stencil.
  2278. #
  2279. # We'll implement this in a few passes. First, we will try to infer the
  2280. # ndim of the cat output. If this ndim != 1, then we know that all ndim =
  2281. # 1 inputs must be empty, or are errors. If this ndim == 1, then life
  2282. # is easy (the legacy special case coincides with regular handling).
  2283. #
  2284. # NB: The regular implementation of cat just filters out empty inputs,
  2285. # but we do it slightly different here for better handling for unbacked
  2286. # SymInts
  2287. example = None
  2288. for i, t in enumerate(tensors):
  2289. if example is None:
  2290. if t.ndim != 1:
  2291. example = t
  2292. else:
  2293. if t.ndim != 1:
  2294. torch._check(
  2295. t.ndim == example.ndim,
  2296. lambda: "Number of dimensions of tensors must match. "
  2297. f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for "
  2298. f"tensor number {i} in the list",
  2299. )
  2300. if example is None:
  2301. # example is None if everything is 1-D. If so, just arbitrarily pick
  2302. # the first one
  2303. example = tensors[0]
  2304. shape = example.shape
  2305. filtered = []
  2306. for tensor_idx, tensor in enumerate(tensors):
  2307. if len(shape) != len(tensor.shape):
  2308. assert tensor.ndim == 1 # we've already checked this above
  2309. # Don't suggest the legacy behavior in the error message
  2310. torch._check(
  2311. tensor.shape[0] == 0,
  2312. lambda: f"Number of dimensions of tensors must match. "
  2313. f"Expected {example.ndim}-D tensors, but got 1-D for "
  2314. f"tensor number {tensor_idx} in the list",
  2315. )
  2316. else:
  2317. # Remove inputs that are 1-D, zero size
  2318. if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0):
  2319. continue
  2320. # Don't bother checking size match, prims.cat will handle it
  2321. filtered.append(tensor)
  2322. memory_format = cat_compute_output_memory_format(tensors)
  2323. if len(filtered) == 0:
  2324. t = tensors[0]
  2325. # TODO: fix this to work with meta tensors
  2326. try:
  2327. requires_grad = any(x.requires_grad for x in tensors)
  2328. except Exception:
  2329. requires_grad = False
  2330. return empty(
  2331. (0,),
  2332. dtype=t.dtype,
  2333. device=t.device,
  2334. requires_grad=requires_grad,
  2335. memory_format=memory_format,
  2336. )
  2337. dim = utils.canonicalize_dim(filtered[0].ndim, dim)
  2338. utils.validate_idx(filtered[0].ndim, dim)
  2339. return prims.cat(filtered, dim).clone(memory_format=memory_format)
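
# Illustrative sketch of the legacy behavior handled above (assuming standard torch
# semantics): a 1-D, zero-length tensor may be concatenated with tensors of any
# other rank, e.g.
#
#   >>> torch.cat([torch.ones(2, 3), torch.empty(0)], dim=0).shape
#   torch.Size([2, 3])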
  2340. # CompositeImplicitAutograd - don't register decomp
  2341. @out_wrapper()
  2342. def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
  2343. aligned_tensors = tuple(
  2344. x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
  2345. )
  2346. return cat(aligned_tensors, 1)
  2347. def conj(input: TensorLikeType) -> TensorLikeType:
  2348. if not utils.is_complex_dtype(input.dtype):
  2349. return input
  2350. if input.is_sparse:
  2351. return torch.conj_physical(input)
  2352. return prims.conj(input)
  2353. # This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
  2354. @register_decomposition(aten.constant_pad_nd)
  2355. @out_wrapper()
  2356. def constant_pad_nd(
  2357. input: TensorLikeType, pad: List[int], value: NumberType = 0
  2358. ) -> TensorLikeType:
  2359. torch._check(
  2360. len(pad) % 2 == 0,
  2361. lambda: f"Length of pad must be even but instead it equals {len(pad)}",
  2362. )
  2363. input_sizes = input.shape
  2364. l_inp = len(input_sizes)
  2365. l_pad = len(pad) // 2
  2366. l_diff = l_inp - l_pad
  2367. torch._check(
  2368. l_inp >= l_pad,
  2369. lambda: "Length of pad should be no more than twice the number of "
  2370. f"dimensions of the input. Pad length is {len(pad)} while the input has "
  2371. f"{l_inp} dimensions.",
  2372. )
  2373. c_input = input
  2374. for i in range(l_diff, l_inp):
  2375. pad_idx = 2 * (l_inp - i - 1)
  2376. if pad[pad_idx] < 0:
  2377. c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])
  2378. if pad[pad_idx + 1] < 0:
  2379. c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
  2380. # if none of the pads are positive we can just return the result
  2381. if builtins.all(p <= 0 for p in pad):
  2382. return c_input.clone()
  2383. new_shape = list(input_sizes[:l_diff])
  2384. for i in range(l_pad):
  2385. pad_idx = len(pad) - ((i + 1) * 2)
  2386. new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
  2387. torch._check(
  2388. new_dim > 0,
  2389. lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
  2390. f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
  2391. f"which is invalid. Check dimension {l_diff + i} of your input.",
  2392. )
  2393. new_shape.append(new_dim)
  2394. memory_format = utils.suggest_memory_format(input)
  2395. output = torch.empty(
  2396. new_shape,
  2397. dtype=input.dtype,
  2398. device=input.device,
  2399. requires_grad=input.requires_grad,
  2400. memory_format=memory_format,
  2401. )
  2402. if value == 0 and input.dtype == torch.bool:
  2403. value = False
  2404. # torch.fill isn't typed to allow complex values
  2405. output = torch.fill(output, value) # type: ignore[arg-type]
  2406. c_output = output
  2407. for i in range(l_diff, l_inp):
  2408. pad_idx = 2 * (l_inp - i - 1)
  2409. if pad[pad_idx] > 0:
  2410. c_output = c_output.narrow(
  2411. i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
  2412. )
  2413. if pad[pad_idx + 1] > 0:
  2414. c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])
  2415. prims.copy_to(c_output, c_input)
  2416. return output
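
# Illustrative sketch (assuming standard torch semantics): negative pad entries
# narrow the corresponding dimension instead of padding it, e.g.
#
#   >>> torch.nn.functional.pad(torch.arange(4.0), (1, -1))
#   tensor([0., 0., 1., 2.])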
  2417. def contiguous(
  2418. a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
  2419. ) -> Tensor:
  2420. torch._check(
  2421. memory_format != torch.preserve_format,
  2422. lambda: "preserve memory format is unsupported by the contiguous operator",
  2423. )
  2424. if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
  2425. return a
  2426. return torch.clone(a, memory_format=memory_format)
  2427. @out_wrapper()
  2428. def dstack(tensors: TensorSequenceType) -> TensorLikeType:
  2429. torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
  2430. aligned_tensors = atleast_3d(*tensors)
  2431. return cat(aligned_tensors, 2)


@register_decomposition(aten.expand)
def expand(a: Tensor, *shape) -> Tensor:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # NOTE: cannot use utils.extract_shape_from_varargs here
    # because that also validates the shape, but the shape
    # given to expand may be "invalid"
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = tuple(shape[0])

    torch._check(
        len(shape) >= len(a.shape),
        lambda: "expand: the requested shape has too few dimensions!",
    )

    offset = len(shape) - len(a.shape)
    shape_ = list(shape)
    for idx, x in enumerate(a.shape):
        offset_idx = idx + offset
        requested_length = shape[offset_idx]
        torch._check(
            guard_size_oblivious(requested_length == x)
            or guard_size_oblivious(x == 1)
            or requested_length == -1,
            lambda: f"expand: attempting to expand a dimension of length {x}!",
        )
        shape_[offset_idx] = requested_length if requested_length != -1 else x

    # At this point shape must be valid
    utils.validate_shape(shape_)

    return prims.broadcast_in_dim(
        a, shape_, tuple(range(offset, len(a.shape) + offset))
    )
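
# Illustrative sketch (assuming standard torch semantics): -1 keeps the existing
# length and only length-1 dimensions may be expanded, e.g.
#
#   >>> torch.ones(1, 3).expand(4, -1).shape
#   torch.Size([4, 3])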
  2461. # CompositeImplicitAutograd - don't register decomp
  2462. def expand_as(a: Tensor, b: Tensor) -> Tensor:
  2463. return a.expand(b.shape)


def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
    if chunks <= 0:
        msg = f"Expected at least one chunk, but got {chunks}!"
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]
    chunk_size = math.ceil(length / chunks)
    full_chunks = math.floor(length / chunk_size)
    tail_chunk_size = length % chunk_size

    result = []
    for i in range(full_chunks):
        result.append(narrow(a, dim, i * chunk_size, chunk_size))

    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
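
# Illustrative sketch (assuming standard torch semantics): chunk sizes follow the
# ceil-division rule above, so the last chunk may be smaller, e.g.
#
#   >>> [t.shape[0] for t in torch.arange(10).chunk(3)]
#   [4, 4, 2]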


# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
# a 0D tensor is flattened, in which case it's returned as a 1D tensor)
# CompositeImplicitAutograd - don't register decomp
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Short-circuits on no-op
    if start_dim == end_dim and a.ndim != 0:
        return a

    # Tries to take a view
    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
    new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
    if new_shape is not None:
        return prims.collapse_view(a, start_dim, end_dim)

    # Makes a copy if it can't make a view
    return prims.collapse(a, start_dim, end_dim)
  2495. @register_decomposition(aten.flip)
  2496. @out_wrapper()
  2497. def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
  2498. if not isinstance(dims, tuple) and not isinstance(dims, list):
  2499. raise ValueError("dims has to be a sequence of ints")
  2500. dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment]
  2501. utils.validate_no_repeating_dims(dims)
  2502. return prims.rev(a, dims)
  2503. # CompositeImplicitAutograd - don't register decomp
  2504. def fliplr(a: TensorLikeType) -> TensorLikeType:
  2505. if a.ndim < 2:
  2506. raise RuntimeError("Input must be >= 2-d.")
  2507. return flip(a, (1,))
  2508. # CompositeImplicitAutograd - don't register decomp
  2509. def flipud(a: TensorLikeType) -> TensorLikeType:
  2510. if a.ndim < 1:
  2511. raise RuntimeError("Input must be >= 1-d.")
  2512. return flip(a, (0,))


# CompositeImplicitAutograd - don't register decomp
def narrow(
    a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
) -> TensorLikeType:
    # Supports Tensor overload that was added for XLA:
    # https://github.com/pytorch/pytorch/issues/31558
    if isinstance(start, TensorLike):
        torch._check(
            start.dim() == 0 and utils.is_integer_dtype(start.dtype),
            lambda: "start must be a 0-dim integral Tensor.",
        )
        start = start.item()  # type: ignore[assignment]
    torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
    torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
    dim = utils.canonicalize_dim(a.ndim, dim)
    dim_length = a.size(dim)
    torch._check_with(
        IndexError,
        -dim_length <= start and start <= dim_length,  # type: ignore[arg-type]
        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
    )
    if start < 0:
        start = start + dim_length
    torch._check(
        start <= dim_length - length,  # type: ignore[arg-type]
        lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
    )
    return prims.slice_in_dim(a, start, start + length, axis=dim)
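
# Illustrative sketch (assuming standard torch semantics): a negative start wraps
# around the dimension, mirroring the range check above, e.g.
#
#   >>> torch.arange(5).narrow(0, -2, 2)
#   tensor([3, 4])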
  2541. # TODO: This must return a sparse tensor if the input is sparse, but refs have
  2542. # no sparse support. See narrow_copy_sparse in core.
  2543. narrow_copy = _make_copy_from_view(narrow)
  2544. def _normalize(
  2545. a: Tensor, norm_dims: DimsType, eps: float
  2546. ) -> Tuple[Tensor, Tensor, Tensor]:
  2547. """Computes mean and 1/std of a tensor along norm_dims.
  2548. Used as a helper function for normalization layers.
  2549. Args:
  2550. a (Tensor): input tensor
  2551. norm_dims (DimsType): dimensions to normalize over
  2552. eps (float): epsilon for numerical stability
  2553. Returns:
  2554. out (Tensor): normalized tensor.
  2555. mean (Tensor): mean of the tensor along norm_dims.
  2556. rstd (Tensor): 1/std of the tensor along norm_dims.
  2557. """
  2558. norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
  2559. computation_dtype = utils.get_computation_dtype(a.dtype)
  2560. a_acc = _maybe_convert_to_dtype(a, computation_dtype)
  2561. assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean
  2562. biased_var, mean = torch.var_mean(
  2563. a_acc, dim=norm_dims, unbiased=False, keepdim=True
  2564. )
  2565. rstd = torch.rsqrt(biased_var + eps)
  2566. out = (a - mean) * rstd
  2567. return out, mean, rstd
  2568. # add all specified dimensions
  2569. def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
  2570. for dim in sorted(dimensions):
  2571. x = torch.unsqueeze(x, dim)
  2572. return x
  2573. @register_decomposition(aten.native_group_norm.default)
  2574. def native_group_norm(
  2575. input: Tensor,
  2576. weight: Optional[Tensor],
  2577. bias: Optional[Tensor],
  2578. batch_size: int,
  2579. num_channels: int,
  2580. flattened_inner_size: int,
  2581. num_groups: int,
  2582. eps: float,
  2583. ) -> Tuple[Tensor, Tensor, Tensor]:
  2584. torch._check(
  2585. input.ndim >= 2,
  2586. lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
  2587. )
  2588. torch._check(
  2589. num_channels % num_groups == 0,
  2590. lambda: "Expected number of channels in input to be divisible by num_groups, "
  2591. + f"but got input of shape {input.shape} and num_groups = {num_groups}",
  2592. )
  2593. # num_channels / num_groups and flattened inner dimension are the reduction axes
  2594. reduction_dims = [2, 3]
  2595. input_reshaped = torch.reshape(
  2596. input,
  2597. [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
  2598. )
  2599. out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
  2600. out = out.view(input.shape)
  2601. broadcast_dims = [0] + list(range(2, input.ndim))
  2602. unsqueeze_bias = None
  2603. if bias is not None:
  2604. unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
  2605. unsqueeze_weight = None
  2606. if weight is not None:
  2607. unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
  2608. if unsqueeze_weight is not None:
  2609. out = out * unsqueeze_weight
  2610. if unsqueeze_bias is not None:
  2611. out = out + unsqueeze_bias
  2612. out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
  2613. mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
  2614. rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
  2615. # remove broadcast dimensions from mean and rstd
  2616. mean = torch.squeeze(mean, reduction_dims)
  2617. rstd = torch.squeeze(rstd, reduction_dims)
  2618. return (out, mean, rstd)
  2619. @register_decomposition(aten.native_layer_norm)
  2620. @out_wrapper("out0", "out1", "out2")
  2621. def native_layer_norm(
  2622. input: Tensor,
  2623. normalized_shape: ShapeType,
  2624. weight: Optional[Tensor],
  2625. bias: Optional[Tensor],
  2626. eps: float,
  2627. ) -> Tuple[Tensor, Tensor, Tensor]:
  2628. normalized_ndim = len(normalized_shape)
  2629. torch._check(
  2630. normalized_ndim >= 1,
  2631. lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
  2632. + "containing at least one element, but got normalized_shape = "
  2633. + str(normalized_shape),
  2634. )
  2635. # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
  2636. # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
  2637. # therefore we use tuple(normalized_shape)
  2638. torch._check(
  2639. weight is None or weight.shape == tuple(normalized_shape),
  2640. lambda: "Expected weight to be of same shape as normalized_shape, but got "
  2641. + "weight of shape "
  2642. + str(weight.shape) # type: ignore[union-attr]
  2643. + " and normalized_shape = "
  2644. + str(normalized_shape),
  2645. )
  2646. torch._check(
  2647. bias is None or bias.shape == tuple(normalized_shape),
  2648. lambda: "Expected bias to be of same shape as normalized_shape, but got "
  2649. + "bias of shape "
  2650. + str(bias.shape) # type: ignore[union-attr]
  2651. + " and normalized_shape = "
  2652. + str(normalized_shape),
  2653. )
  2654. torch._check(
  2655. input.ndim >= normalized_ndim
  2656. and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
  2657. lambda: "Given normalized_shape="
  2658. + str(normalized_shape)
  2659. + ", expected input with shape "
  2660. + str(normalized_shape)
  2661. + ", but got input of size "
  2662. + str(input.shape),
  2663. )
  2664. input = input.contiguous()
  2665. if weight is not None:
  2666. weight = weight.contiguous()
  2667. if bias is not None:
  2668. bias = bias.contiguous()
  2669. axis = input.ndim - normalized_ndim
  2670. reduction_dims = list(range(axis, input.ndim))
  2671. out, mean, rstd = _normalize(input, reduction_dims, eps)
  2672. if weight is None and bias is not None:
  2673. out = out + bias
  2674. elif weight is not None and bias is None:
  2675. out = out * weight
  2676. elif weight is not None and bias is not None:
  2677. out = out * weight + bias
  2678. out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
  2679. if input.device.type == "cpu":
  2680. mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
  2681. rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
  2682. return (out, mean, rstd)
  2683. # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
  2684. # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
  2685. @register_decomposition(aten.permute)
  2686. def permute(a: TensorLikeType, *dims) -> TensorLikeType:
  2687. _permutation = utils.canonicalize_dims(
  2688. a.ndim, utils.extract_dims_from_varargs(dims)
  2689. )
  2690. return prims.transpose(a, _permutation)
  2691. @register_decomposition(aten.renorm)
  2692. @out_wrapper()
  2693. def renorm(
  2694. input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType
  2695. ) -> TensorLikeType:
  2696. torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued")
  2697. torch._check(p > 0, lambda: "renorm: non-positive norm not supported")
  2698. torch._check(
  2699. not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued"
  2700. )
  2701. torch._check(
  2702. maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}"
  2703. )
  2704. ndim = input.ndim
  2705. torch._check(
  2706. ndim > 1,
  2707. lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions",
  2708. )
  2709. dim = utils.canonicalize_dim(ndim, dim)
  2710. reduce_dims = list(range(ndim))
  2711. del reduce_dims[dim]
  2712. # For half and bfloat16, calculate norm in float precision then cast
  2713. # normalization factor to half
  2714. acc_type = utils.get_computation_dtype(input.dtype)
  2715. if acc_type != input.dtype:
  2716. norm = torch.linalg.vector_norm(
  2717. input, p, reduce_dims, keepdim=True, dtype=acc_type
  2718. )
  2719. else:
  2720. norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True)
  2721. eps = 1e-7
  2722. norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
  2723. if acc_type != input.dtype:
  2724. norm_factor = prims.convert_element_type(norm_factor, input.dtype)
  2725. return (input * norm_factor).contiguous()
  2726. # CompositeImplicitAutograd - don't register decomp
  2727. @aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
  2728. def stft(
  2729. input: Tensor,
  2730. n_fft: int,
  2731. hop_length: Optional[int] = None,
  2732. win_length: Optional[int] = None,
  2733. window: Optional[Tensor] = None,
  2734. center: bool = True,
  2735. pad_mode: str = "reflect",
  2736. normalized: bool = False,
  2737. onesided: Optional[bool] = None,
  2738. return_complex: Optional[bool] = None,
  2739. ) -> Tensor:
  2740. torch._check(
  2741. window is None or window.device == input.device,
  2742. lambda: (
  2743. f"stft input and window must be on the same device but got self on {input.device}"
  2744. + f" and window on {window.device}" # type: ignore[union-attr]
  2745. ),
  2746. )
  2747. hop_length_ = hop_length if hop_length is not None else n_fft // 4
  2748. win_length_ = win_length if win_length is not None else n_fft
  2749. if return_complex is None:
  2750. return_complex_ = input.is_complex() or (
  2751. window is not None and utils.is_complex_dtype(window.dtype)
  2752. )
  2753. torch._check(
  2754. return_complex_,
  2755. (
  2756. "stft requires the return_complex parameter be given for real inputs, "
  2757. + "and will further require that return_complex=True in a future PyTorch release."
  2758. ),
  2759. )
  2760. else:
  2761. return_complex_ = return_complex
  2762. torch._check(
  2763. utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
  2764. lambda: "stft expected a tensor of floating point or complex values",
  2765. )
  2766. torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")
  2767. original_ndim = input.ndim
  2768. if original_ndim == 1:
  2769. input = input.unsqueeze(0)
  2770. if center:
  2771. extra_dims = 3 - input.ndim
  2772. pad_amount = n_fft // 2
  2773. extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
  2774. input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
  2775. input = input.view(input.size()[extra_dims:])
  2776. batch = input.size(0)
  2777. length = input.size(1)
  2778. torch._check(
  2779. 0 < n_fft <= length,
  2780. lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
  2781. )
  2782. torch._check(
  2783. hop_length_ > 0,
  2784. lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
  2785. )
  2786. torch._check(
  2787. 0 < win_length_ <= n_fft,
  2788. lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
  2789. )
  2790. torch._check(
  2791. window is None or window.shape == (win_length_,),
  2792. lambda: (
  2793. f"expected a 1D window tensor of size equal to win_length={win_length_}, "
  2794. + f"but got window with size {window.shape}" # type: ignore[union-attr]
  2795. ),
  2796. )
  2797. if win_length_ < n_fft:
  2798. if window is None:
  2799. window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
  2800. left = (n_fft - win_length_) // 2
  2801. window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
  2802. input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
  2803. if window is not None:
  2804. input = input * window
  2805. complex_fft = utils.is_complex_dtype(input.dtype)
  2806. onesided = onesided if onesided is not None else not complex_fft
  2807. norm = "ortho" if normalized else None
  2808. if onesided:
  2809. torch._check(
  2810. not complex_fft,
  2811. lambda: "Cannot have onesided output if window or input is complex",
  2812. )
  2813. out = torch.fft.rfft(input, dim=-1, norm=norm)
  2814. else:
  2815. out = torch.fft.fft(input, dim=-1, norm=norm)
  2816. out.transpose_(1, 2)
  2817. if original_ndim == 1:
  2818. out = out.squeeze_(0)
  2819. return out if return_complex_ else torch.view_as_real(out)
  2820. # CompositeImplicitAutograd - don't register decomp
  2821. @aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2822. def istft(
  2823. input: Tensor,
  2824. n_fft: int,
  2825. hop_length: Optional[int] = None,
  2826. win_length: Optional[int] = None,
  2827. window: Optional[Tensor] = None,
  2828. center: bool = True,
  2829. normalized: bool = False,
  2830. onesided: Optional[bool] = None,
  2831. length: Optional[int] = None,
  2832. return_complex=False,
  2833. ) -> Tensor:
  2834. torch._check(
  2835. window is None or window.device == input.device,
  2836. lambda: (
  2837. f"istft input and window must be on the same device but got self on {input.device}"
  2838. + f" and window on {window.device}" # type: ignore[union-attr]
  2839. ),
  2840. )
  2841. hop_length_ = hop_length if hop_length is not None else n_fft // 4
  2842. win_length_ = win_length if win_length is not None else n_fft
    torch._check(
        utils.is_complex_dtype(input.dtype),
        lambda: (
            "istft requires a complex-valued input tensor, as produced by "
            + "stft(..., return_complex=True)"
        ),
    )
  2850. n_frames = input.size(-1)
  2851. fft_size = input.size(-2)
  2852. expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
  2853. torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
  2854. torch._check(
  2855. 2 <= input.ndim <= 3,
  2856. lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
  2857. )
    onesided_ = onesided if onesided is not None else fft_size != n_fft
    if onesided_:
        torch._check(
            n_fft // 2 + 1 == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + f"to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
            ),
        )
    else:
        torch._check(
            n_fft == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + f"to match n_fft when onesided=False, but got {fft_size}"
            ),
        )
    torch._check(
        0 < hop_length_ <= win_length_,
        lambda: "istft expected 0 < hop_length <= win_length",
    )
    torch._check(
        0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
    )
    torch._check(
        window is None or window.shape == (win_length_,),
        lambda: "Invalid window shape. window has to be 1D with length equal to `win_length`",
    )
  2886. if window is None:
  2887. real_dtype = utils.corresponding_real_dtype(input.dtype)
  2888. window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
  2889. else:
  2890. window_ = window
  2891. if win_length_ != n_fft:
  2892. left = (n_fft - win_length_) // 2
  2893. window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)
  2894. original_ndim = input.ndim
  2895. if input.ndim == 2:
  2896. input = input.unsqueeze(0)
  2897. input = input.transpose(1, 2)
  2898. norm = "ortho" if normalized else None
  2899. if return_complex:
  2900. torch._check(
  2901. not onesided_,
  2902. lambda: "cannot have onesided output if window or input is complex",
  2903. )
  2904. input = torch.fft.ifft(input, dim=-1, norm=norm)
  2905. else:
  2906. torch._check(
  2907. window is None or not utils.is_complex_dtype(window.dtype),
  2908. lambda: "Complex windows are incompatible with return_complex=False",
  2909. )
  2910. if not onesided_:
  2911. input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
  2912. input = torch.fft.irfft(input, dim=-1, norm=norm)
  2913. assert input.size(2) == n_fft
  2914. y_tmp = input * window_.view([1, 1, n_fft])
  2915. y = aten.unfold_backward(
  2916. y_tmp,
  2917. input_sizes=(y_tmp.size(0), expected_output_signal_len),
  2918. dim=1,
  2919. size=n_fft,
  2920. step=hop_length_,
  2921. )
  2922. window_envelop = aten.unfold_backward(
  2923. window_.pow(2).expand((1, n_frames, n_fft)),
  2924. input_sizes=(y_tmp.size(0), expected_output_signal_len),
  2925. dim=1,
  2926. size=n_fft,
  2927. step=hop_length_,
  2928. )
  2929. assert expected_output_signal_len == y.size(1)
  2930. assert expected_output_signal_len == window_envelop.size(1)
  2931. start = n_fft // 2 if center else 0
  2932. if length is not None:
  2933. end = start + length
  2934. elif center:
  2935. end = expected_output_signal_len - n_fft // 2
  2936. else:
  2937. end = expected_output_signal_len
  2938. length = max(0, end - start)
  2939. y = y.narrow(dim=1, start=start, length=length)
  2940. window_envelop = window_envelop.narrow(dim=1, start=start, length=length)
  2941. window_envelop_lowest = window_envelop.abs().min().lt(1e-11)
  2942. torch._check(
  2943. not window_envelop_lowest.item(),
  2944. lambda: "window overlap add min less than 1e-11",
  2945. )
  2946. y = y / window_envelop
  2947. if original_ndim == 2:
  2948. y = y.squeeze(0)
  2949. if end > expected_output_signal_len:
  2950. warnings.warn(
  2951. "The length of signal is shorter than the length parameter. Result is being "
  2952. + "padded with zeros in the tail. Please check your center and hop_length settings"
  2953. )
  2954. y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
  2955. return y


# Get the new shape and stride after applying unfold to an input tensor
def _get_unfold_shape_stride(
    a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
):
    a_ndim = len(a_shape)
    dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
    max_size = 1 if a_ndim == 0 else a_shape[dim]
    last_stride = 1 if a_ndim == 0 else a_stride[dim]

    torch._check(
        size <= max_size,
        lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
    )

    torch._check(
        step > 0,
        lambda: f"Step is {step} but must be > 0",
    )

    shape = list(a_shape)
    strides = list(a_stride)
    shape.append(size)
    strides.append(last_stride)
    if dim < a_ndim:
        shape[dim] = (shape[dim] - size) // step + 1
        strides[dim] *= step
    return shape, strides
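
# Illustrative sketch of the formula above: for a contiguous (10,) tensor with
# stride (1,), unfolding dim 0 with size=3 and step=2 yields shape [4, 3] and
# strides [2, 1], since (10 - 3) // 2 + 1 == 4.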
  2980. @register_decomposition(aten.repeat)
  2981. @out_wrapper()
  2982. def repeat(a: Tensor, *repeat_shape) -> Tensor:
  2983. repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
  2984. torch._check(
  2985. len(repeat_shape) >= len(a.shape),
  2986. lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
  2987. )
  2988. if len(repeat_shape) == 0:
  2989. return torch.clone(a)
  2990. num_new_dimensions = len(repeat_shape) - a.ndim
  2991. padded_shape = [1] * num_new_dimensions
  2992. for dim_size in a.shape:
  2993. padded_shape.append(dim_size)
  2994. target_shape = tuple(
  2995. padded_size * repeat_size
  2996. for padded_size, repeat_size in zip(padded_shape, repeat_shape)
  2997. )
  2998. # return an empty tensor if one of the repeat_shape dimensions is zero
  2999. if 0 in repeat_shape:
  3000. return torch.empty(
  3001. target_shape,
  3002. dtype=a.dtype,
  3003. device=a.device,
  3004. requires_grad=a.requires_grad,
  3005. memory_format=utils.suggest_memory_format(a),
  3006. )
  3007. urtensor_shape = target_shape
  3008. urtensor_stride = utils.make_contiguous_strides_for(target_shape)
  3009. for dim, dim_size in enumerate(padded_shape):
  3010. # repeat each dimension by using unfold_copy operation
  3011. urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
  3012. urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
  3013. )
  3014. # derive permute order by sorting urtensor strides
  3015. enumerated_stride = list(enumerate(urtensor_stride))
  3016. enumerated_stride.sort(key=operator.itemgetter(1), reverse=True)
  3017. permute_order, sorted_stride = zip(*enumerated_stride)
  3018. # add new and expand dimensions according to urtensor
  3019. repeat_xtensor = a.expand(urtensor_shape)
  3020. # clone tensor to concretize expanded dimensions
  3021. cloned_result = torch.clone(repeat_xtensor)
  3022. # transpose axis so strides are in sorted order
  3023. permuted_result = cloned_result.permute(permute_order)
  3024. # reshape to get contiguous tensor with correct target shape
  3025. return permuted_result.reshape(target_shape)
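
# Illustrative sketch (assuming standard torch semantics): repeat tiles the tensor
# along each dimension of repeat_shape, e.g.
#
#   >>> torch.tensor([1, 2]).repeat(2, 3).shape
#   torch.Size([2, 6])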
  3026. def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
  3027. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
  3028. # Creates a valid shape
  3029. shape = utils.extract_shape_from_varargs(shape, validate=False)
  3030. # Reshape may be given a shape with a -1 length
  3031. # This indicates that the dimension's length should be inferred
  3032. shape = utils.infer_size(shape, a.numel())
  3033. # Special-cases tensors with no elements
  3034. if guard_size_oblivious(a.numel() == 0):
  3035. return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
  3036. # Special-cases reshaping zero dim tensors
  3037. if a.ndim == 0:
  3038. _a = a
  3039. for length in shape:
  3040. assert length == 1
  3041. _a = unsqueeze(_a, -1)
  3042. if _a is a:
  3043. return prims.view_of(a)
  3044. else:
  3045. return _a
  3046. # Special-cases reshaping to zero dim tensors
  3047. if len(shape) == 0:
  3048. _a = a
  3049. for length in a.shape:
  3050. assert length == 1
  3051. _a = squeeze(_a, -1)
  3052. if _a is a:
  3053. return prims.view_of(a)
  3054. else:
  3055. return _a
  3056. if a.is_contiguous():
  3057. # Special-cases for nd_to_1d
  3058. if len(shape) == 1 and a.ndim > 1:
  3059. return torch.as_strided(a, [a.numel()], [1])
  3060. # Special-cases for 1d_to_2d
  3061. if len(shape) == 2 and a.ndim == 1:
  3062. dim0 = shape[0]
  3063. dim1 = shape[1]
  3064. return torch.as_strided(a, [dim0, dim1], [dim1, 1])
    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
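    #
    # Illustrative sketch of the accumulation: reshaping (2, 3, 4) -> (6, 4) gathers
    # the first two input dims (2 * 3 == 6), flattens them into one dim of length 6,
    # and leaves the trailing dim of length 4 untouched.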
  3083. idx = 0
  3084. a_ = a
  3085. for length in shape:
  3086. # Handles tail unsqueezes
  3087. if idx >= a_.ndim:
  3088. assert length == 1
  3089. last_dim = a_.ndim - 1
  3090. # NOTE: using split_dim instead of unsqueeze may seem silly here,
  3091. # but it's necessary to get the strides correct
  3092. a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
  3093. idx = idx + 1
  3094. continue
  3095. # Skips dimensions that are already the correct length
  3096. if guard_size_oblivious(length == a_.shape[idx]):
  3097. idx = idx + 1
  3098. continue
  3099. # Gathers enough original dimensions such that this new dimension can be created
  3100. # Note that this accumulation will terminate because we've verified a and the shape
  3101. # specify the same number of elements above
  3102. accum = a_.shape[idx]
  3103. end = idx
  3104. while guard_size_oblivious(accum % length != 0):
  3105. end = end + 1
  3106. accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy
            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)
  3119. # Splits the (possibly flattened) dimension to create the desired dim length
  3120. if guard_size_oblivious(accum != length):
  3121. a_ = prims.split_dim(a_, idx, length)
  3122. idx = idx + 1
  3123. # Squeezes tail
  3124. while idx < a_.ndim:
  3125. torch._check(
  3126. a_.shape[idx] == 1,
  3127. lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
  3128. )
  3129. a_ = squeeze(a_, idx)
  3130. if a_ is a:
  3131. return prims.view_of(a)
  3132. else:
  3133. return a_


# CompositeImplicitAutograd - don't register decomp
# NOTE: shape is a vararg because Tensor.reshape can be called as
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)); the function call
# torch.reshape doesn't support unpacked shapes
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=True)


# CompositeImplicitAutograd - don't register decomp
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.reshape(other.size())
  3143. @register_decomposition(aten.roll)
  3144. @out_wrapper()
  3145. def roll(
  3146. a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
  3147. ) -> TensorLikeType:
  3148. """Reference implementation of :func:`torch.roll`."""
  3149. dims = utils.canonicalize_dims(a.ndim, dims)
  3150. # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
  3151. if not isinstance(shifts, Iterable):
  3152. shifts = (shifts,)
  3153. if not isinstance(dims, Iterable):
  3154. dims = (dims,)
  3155. # Avoid modulo by zero
  3156. if a.numel() == 0:
  3157. # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
  3158. return a.clone()
  3159. if a.dim() == 0 and len(dims) > 0:
  3160. raise IndexError(
  3161. f"Dimension specified as {dims[0]} but tensor has no dimensions"
  3162. )
  3163. len_shifts = len(shifts)
  3164. len_dims = len(dims)
  3165. if len_shifts != 1 or len_dims != 1:
  3166. if len_shifts == 0:
  3167. raise RuntimeError("`shifts` required")
  3168. # Takes care of the case when dims is not specified (default)
  3169. # By default, the tensor is flattened before shifting, after which the original shape is restored
  3170. if len_dims == 0 and len_shifts == 1:
  3171. return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
  3172. if len_shifts != len_dims:
  3173. raise RuntimeError(
  3174. f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
  3175. )
  3176. assert len_dims > 1
  3177. tail_shifts = shifts[1:]
  3178. tail_dims = dims[1:]
  3179. first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
  3180. return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
  3181. # This path is taken when only one dimension is rolled
  3182. # For example to get `first_dim_rolled` above
  3183. dim = dims[0]
  3184. size = a.shape[dim]
  3185. start = (size - shifts[0]) % size
  3186. idx = torch.arange(size, device=a.device)
  3187. return a.index_select(dim, torch.fmod(start + idx, size))
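
# Illustrative sketch (assuming standard torch semantics): the index_select path
# above rotates elements along a single dimension, e.g.
#
#   >>> torch.roll(torch.arange(5), shifts=2, dims=0)
#   tensor([3, 4, 0, 1, 2])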
  3188. @register_decomposition(aten.rot90)
  3189. @out_wrapper()
  3190. def rot90(
  3191. a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
  3192. ) -> TensorLikeType:
  3193. """Reference implementation of :func:`torch.rot90`."""
  3194. if len(dims) != 2:
  3195. raise RuntimeError(
  3196. f"expected total rotation dims == 2, but got dims = {len(dims)}"
  3197. )
  3198. if a.ndim < 2:
  3199. raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
  3200. # Do this after the initial checks to be compatible with the behavior in
  3201. # core.
  3202. dims = utils.canonicalize_dims(a.ndim, dims)
  3203. if dims[0] == dims[1]:
  3204. raise RuntimeError(
  3205. f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
  3206. )
  3207. k = k % 4 # Rotation direction is from the second towards the first axis for k < 0
  3208. if k == 1:
  3209. return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
  3210. elif k == 2:
  3211. return torch.flip(a, dims)
  3212. elif k == 3:
  3213. return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
  3214. else:
  3215. return clone(a, memory_format=torch.contiguous_format)


def _check_stack_inputs(tensors: TensorSequenceType) -> None:
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )


@register_decomposition(aten.stack)
@out_wrapper()
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
    # Refs need sparse support to check other condition
    if wrapped_dim < tensors[0].ndim:  # and not tensors[0].is_sparse:
        _check_stack_inputs(tensors)
        result_sizes = list(tensors[0].shape)
        result_sizes.insert(wrapped_dim, len(tensors))
        out = torch.cat(tensors, wrapped_dim)
        return out.view(result_sizes)

    # If dim == tensors[0].ndim, view cannot efficiently handle it
    return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
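
# Illustrative sketch (assuming standard torch semantics): stack inserts a new
# dimension of size len(tensors) at `dim`, e.g.
#
#   >>> torch.stack([torch.zeros(2, 3), torch.ones(2, 3)], dim=1).shape
#   torch.Size([2, 2, 3])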


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    if a.numel() == 0:
        a_exp = exp(a_)
    else:
        a_max = amax(a_, dim, keepdim=True)
        a_exp = exp(a_ - a_max)
    return _maybe_convert_to_dtype(
        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
    )  # type: ignore[return-value]
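
# Illustrative note: subtracting the per-dim max above keeps exp() from overflowing
# without changing the result, e.g. softmax([1000., 1001.]) equals
# softmax([0., 1.]) == [1 / (1 + e), e / (1 + e)].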
  3255. # CompositeImplicitAutograd - don't register decomp
  3256. @out_wrapper()
  3257. def hstack(tensors: TensorSequenceType) -> TensorLikeType:
  3258. torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
  3259. aligned_tensors = atleast_1d(*tensors)
  3260. if aligned_tensors[0].ndim == 1:
  3261. return cat(aligned_tensors, 0)
  3262. return cat(aligned_tensors, 1)
  3263. # CompositeImplicitAutograd - don't register decomp
  3264. @out_wrapper()
  3265. def vstack(tensors: TensorSequenceType) -> TensorLikeType:
  3266. torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
  3267. aligned_tensors = atleast_2d(*tensors)
  3268. return cat(aligned_tensors, 0)
  3269. # CompositeImplicitAutograd - don't register decomp
  3270. def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
  3271. dim = utils.canonicalize_dim(a.ndim, dim)
  3272. torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
  3273. return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
  3274. @register_decomposition(aten.unbind)
  3275. def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
  3276. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  3277. dim = utils.canonicalize_dim(t.ndim, dim)
  3278. torch._check_index(
  3279. len(t.shape) > 0,
  3280. lambda: "Dimension specified as 0 but tensor has no dimensions",
  3281. )
  3282. if guard_size_oblivious(t.shape[dim] == 0):
  3283. return tuple()
  3284. else:
  3285. return tuple(
  3286. torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
  3287. )
  3288. @out_wrapper()
  3289. def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  3290. return x.clone(memory_format=torch.contiguous_format).index_copy_(
  3291. dim, index, tensor
  3292. )
  3293. def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  3294. dim = utils.canonicalize_dims(x.ndim, dim)
  3295. torch._check(
  3296. index.ndim <= 1,
  3297. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  3298. )
  3299. # Treat scalars as elements of \R^1
  3300. y = x.unsqueeze(0) if x.ndim == 0 else x
  3301. idx = (slice(None),) * dim + (index,)
  3302. y[idx] = tensor
  3303. return x
  3304. @register_decomposition(aten.index_fill)
  3305. @out_wrapper()
  3306. def index_fill(
  3307. x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
  3308. ):
  3309. return _index_fill(x, dim, index, value, inplace=False)
  3310. @register_decomposition(aten.index_fill_)
  3311. def index_fill_(
  3312. x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
  3313. ):
  3314. return _index_fill(x, dim, index, value, inplace=True)
  3315. def _index_fill(
  3316. x: TensorLike,
  3317. dim: int,
  3318. index: TensorLike,
  3319. value: Union[NumberType, TensorLike],
  3320. *,
  3321. inplace: bool,
  3322. ):
  3323. torch._check(
  3324. index.ndim <= 1,
  3325. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  3326. )
  3327. if isinstance(value, TensorLike):
  3328. torch._check(
  3329. value.ndim == 0,
  3330. lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
  3331. f"Got a tensor with {value.ndim} dimensions.",
  3332. ) # type: ignore[arg-type]
  3333. else:
  3334. value = torch.scalar_tensor(
  3335. value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type]
  3336. )
  3337. # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them
  3338. zero_dim = x.ndim == 0
  3339. y = x.unsqueeze(0) if zero_dim else x
  3340. # index_copy does not broadcast on value so we have to do it manually
  3341. shape = list(y.shape)
  3342. shape[dim] = index.numel()
  3343. value = value.expand(shape)
  3344. index_copy = Tensor.index_copy_ if inplace else torch.index_copy
  3345. out = index_copy(y, dim, index, value) # type: ignore[operator]
  3346. if inplace:
  3347. return x
  3348. else:
  3349. if zero_dim:
  3350. # The clone is necessary so that it returns a fresh tensor rather than a view
  3351. out = out.squeeze(0).clone()
  3352. # index_fill preserves the strides. index_copy always returns contiguous tensors
  3353. if out.stride() != x.stride():
  3354. new_out = torch.empty_like(x)
  3355. new_out.copy_(out)
  3356. out = new_out
  3357. return out
  3358. @out_wrapper()
  3359. def index_add(
  3360. x: TensorLike,
  3361. dim: int,
  3362. index: TensorLike,
  3363. tensor: TensorLike,
  3364. *,
  3365. alpha: NumberType = 1,
  3366. ):
  3367. # index_add always returns a new contiguous tensor
  3368. return x.clone(memory_format=torch.contiguous_format).index_add_(
  3369. dim, index, tensor, alpha=alpha # type: ignore[arg-type]
  3370. )
  3371. @register_decomposition(aten.index_select)
  3372. @out_wrapper()
  3373. def index_select(x: TensorLike, dim: int, index: TensorLike):
  3374. dim = utils.canonicalize_dims(x.ndim, dim)
  3375. torch._check(
  3376. index.ndim <= 1,
  3377. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  3378. )
  3379. if index.ndim == 0:
  3380. index = index.unsqueeze(0)
  3381. if x.ndim == 0:
  3382. # Treat scalars as elements of \R^1
  3383. # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
  3384. return torch.empty_like(x).index_copy(0, index, x.expand_as(index))
  3385. idx = (slice(None),) * dim + (index,)
  3386. return x[idx]
  3387. @register_decomposition(aten.squeeze.dims)
  3388. def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
  3389. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  3390. if dim is None:
  3391. dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
  3392. return prims.squeeze(a, dims) if dims else prims.view_of(a)
  3393. ndim = a.ndim
  3394. dim = utils.canonicalize_dims(ndim, dim)
  3395. dims = (dim,) if isinstance(dim, Dim) else dim
  3396. # Short-circuits if the tensor has no dimensions
  3397. if ndim == 0:
  3398. assert len(dims) == 0 or dims == (0,)
  3399. return prims.view_of(a)
  3400. # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
  3401. dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
  3402. if len(dims) == 0:
  3403. return prims.view_of(a)
  3404. if len(dims) == 1:
  3405. return prims.squeeze(a, dims)
  3406. dims_list = list(dims)
  3407. dims_list = sorted(dims_list, reverse=True)
  3408. for i in dims_list:
  3409. a = squeeze(a, i)
  3410. return a
  3411. # Note: does not work with TensorMetas because of data-dependent control-flow
  3412. # CompositeImplicitAutograd - don't register decomp
  3413. def tensor_split(
  3414. a: TensorLikeType,
  3415. indices_or_sections: Union[Tensor, DimsType],
  3416. dim: int = 0,
  3417. ) -> Tuple[TensorLikeType, ...]:
  3418. _dim = utils.canonicalize_dim(a.ndim, dim)
  3419. if a.ndim == 0:
  3420. msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
  3421. raise ValueError(msg)
  3422. # If indices_or_sections is a tensor, it must be a CPU Long tensor
  3423. if isinstance(indices_or_sections, TensorLike):
  3424. if not indices_or_sections.device.type == "cpu":
  3425. msg = (
  3426. f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
  3427. f"but received one on {indices_or_sections.device}"
  3428. )
  3429. raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                f"but received one with dtype {indices_or_sections.dtype}"
            )
            raise ValueError(msg)
  3434. # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
  3435. if isinstance(indices_or_sections, IntLike) or (
  3436. isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
  3437. ):
  3438. sections: int = (
  3439. indices_or_sections # type: ignore[assignment]
  3440. if isinstance(indices_or_sections, Number)
  3441. else indices_or_sections.item()
  3442. )
  3443. if sections <= 0:
  3444. msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
  3445. raise ValueError(msg)
  3446. splits = []
  3447. dim_size = a.shape[_dim]
  3448. min_split_size = math.floor(dim_size / sections)
  3449. num_splits_one_extra = dim_size % sections
  3450. start_idx = 0
  3451. for split_idx in range(sections):
  3452. split_size = (
  3453. min_split_size + 1
  3454. if (split_idx < num_splits_one_extra)
  3455. else min_split_size
  3456. )
  3457. s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
  3458. splits.append(s)
  3459. start_idx = start_idx + split_size
  3460. return tuple(splits)
  3461. # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
  3462. else:
  3463. indices = indices_or_sections
  3464. if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    f"but received a tensor with {indices_or_sections.ndim} dimensions"
                )
                raise ValueError(msg)
  3469. indices = indices_or_sections.tolist()
  3470. splits = []
  3471. start_idx = 0
  3472. for x in indices:
  3473. splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
  3474. start_idx = x
  3475. splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
  3476. return tuple(splits)
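
# Illustrative sketch (assuming standard torch semantics): an integer splits into
# equal-ish parts, while a sequence gives explicit split points, e.g.
#
#   >>> [t.tolist() for t in torch.tensor_split(torch.arange(5), 3)]
#   [[0, 1], [2, 3], [4]]
#   >>> [t.tolist() for t in torch.tensor_split(torch.arange(5), (1, 3))]
#   [[0], [1, 2], [3, 4]]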
  3477. # CompositeImplicitAutograd - don't register decomp
  3478. def hsplit(
  3479. a: TensorLikeType, indices_or_sections: DimsType
  3480. ) -> Tuple[TensorLikeType, ...]:
  3481. torch._check(
  3482. a.ndim >= 1,
  3483. lambda: (
  3484. "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
  3485. + str(a.ndim)
  3486. + " dimensions!"
  3487. ),
  3488. )
  3489. dim = 0 if a.ndim == 1 else 1
  3490. if isinstance(indices_or_sections, IntLike):
  3491. split_size = indices_or_sections
  3492. torch._check(
  3493. (split_size != 0 and a.shape[dim] % split_size == 0),
  3494. lambda: (
  3495. "torch.hsplit attempted to split along dimension "
  3496. + str(dim)
  3497. + ", but the size of the dimension "
  3498. + str(a.shape[dim])
  3499. + " is not divisible by the split_size "
  3500. + str(split_size)
  3501. + "!"
  3502. ),
  3503. )
  3504. return tensor_split(a, split_size, dim)
  3505. torch._check_type(
  3506. isinstance(indices_or_sections, (list, tuple)),
  3507. lambda: (
  3508. "hsplit(): received an invalid combination of arguments. "
  3509. "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
  3510. f"but got type {type(indices_or_sections)}"
  3511. ),
  3512. )
  3513. split_sizes = indices_or_sections
  3514. return tensor_split(a, split_sizes, dim)
  3515. # CompositeImplicitAutograd - don't register decomp
  3516. def vsplit(
  3517. a: TensorLikeType, indices_or_sections: DimsType
  3518. ) -> Tuple[TensorLikeType, ...]:
  3519. torch._check(
  3520. a.ndim >= 2,
  3521. lambda: (
  3522. "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
  3523. + str(a.ndim)
  3524. + " dimensions!"
  3525. ),
  3526. )
  3527. if isinstance(indices_or_sections, IntLike):
  3528. split_size = indices_or_sections
  3529. torch._check(
  3530. (split_size != 0 and a.shape[0] % split_size == 0),
  3531. lambda: (
  3532. f"torch.vsplit attempted to split along dimension 0"
  3533. f", but the size of the dimension "
  3534. f"{a.shape[0]}"
  3535. f" is not divisible by the split_size "
  3536. f"{split_size}"
  3537. f"!"
  3538. ),
  3539. )
  3540. return tensor_split(a, split_size, 0)
  3541. torch._check_type(
  3542. isinstance(indices_or_sections, (list, tuple)),
  3543. lambda: (
  3544. "vsplit(): received an invalid combination of arguments. "
  3545. "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
  3546. f"but got type {type(indices_or_sections)}"
  3547. ),
  3548. )
  3549. split_sizes = indices_or_sections
  3550. return tensor_split(a, split_sizes, 0)
  3551. @register_decomposition(aten.diag.out)
  3552. @out_wrapper()
  3553. def diag(
  3554. self: TensorLikeType,
  3555. offset: int = 0,
  3556. ) -> TensorLikeType:
  3557. ndim = self.dim()
  3558. torch._check(
  3559. ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
  3560. )
  3561. if ndim == 1:
  3562. return torch.diag_embed(self, offset)
  3563. else:
  3564. return torch.diagonal_copy(self, offset)
  3565. @register_decomposition(aten.diagonal_scatter)
  3566. @out_wrapper()
  3567. def diagonal_scatter(
  3568. input: TensorLikeType,
  3569. src: TensorLikeType,
  3570. offset: int = 0,
  3571. dim1: int = 0,
  3572. dim2: int = 1,
  3573. ) -> TensorLikeType:
  3574. out = utils.clone_preserve_strides(input)
  3575. diag = out.diagonal(offset, dim1, dim2)
  3576. torch._check(
  3577. diag.shape == src.shape,
  3578. lambda: "expected src to have a size equal to the diagonal of the input."
  3579. f"Got {src.shape} for a diagonal of shape {diag.shape}",
  3580. )
  3581. copy_to(diag, src)
  3582. return out
  3583. @register_decomposition(aten.diagonal)
  3584. def diagonal(
  3585. self: TensorLikeType,
  3586. offset: int = 0,
  3587. dim1: int = 0,
  3588. dim2: int = 1,
  3589. ) -> TensorLikeType:
  3590. """
  3591. Reference implementation of torch.diagonal
  3592. """
  3593. num_dims = self.dim()
  3594. dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
  3595. dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
  3596. torch._check(
  3597. dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
  3598. )
  3599. storage_offset = self.storage_offset()
  3600. if offset >= 0:
  3601. diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
  3602. else:
  3603. diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)
  3604. if diag_size > 0:
  3605. if offset >= 0:
  3606. storage_offset += offset * self.stride()[dim2]
  3607. else:
  3608. storage_offset -= offset * self.stride()[dim1]
  3609. sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
  3610. sizes.append(diag_size)
  3611. strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
  3612. strides.append(self.stride()[dim1] + self.stride()[dim2])
  3613. result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)
  3614. return result
  3615. diagonal_copy = _make_copy_from_view(diagonal)
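# Worked sketch of the stride arithmetic above (informal): for a contiguous
# (3, 4) tensor `a` with strides (4, 1) and offset=1,
#   diag_size      = min(3, 4 - 1) = 3
#   storage_offset = base + 1 * stride[dim2] = base + 1
#   view strides   = [stride[dim1] + stride[dim2]] = [5]
# so the view reads flat offsets 1, 6, 11, i.e. a[0, 1], a[1, 2], a[2, 3].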
  3616. @register_decomposition(aten.diag_embed)
  3617. @out_wrapper()
  3618. def diag_embed(
  3619. t: TensorLikeType,
  3620. offset: int = 0,
  3621. dim1: int = -2,
  3622. dim2: int = -1,
  3623. ) -> TensorLikeType:
  3624. """
  3625. Reference implementation of torch.diag_embed
  3626. """
  3627. # convert from negative dims
  3628. rank = t.ndim + 1
  3629. dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
  3630. dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
  3631. # as per the docs, exchanging dims is equivalent to changing the sign of
  3632. # offset
  3633. if dim1 > dim2:
  3634. dim1, dim2 = dim2, dim1
  3635. offset = -offset
  3636. torch._check(
  3637. dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
  3638. )
  3639. # as per the docs, the size of last dim is placed at dim1 and dim2
  3640. last_dim = t.size(-1)
  3641. if offset != 0:
  3642. # add padding to match the new size
  3643. t_shape = list(t.shape)
  3644. t_shape[-1] = builtins.abs(offset)
  3645. z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
  3646. pair = (z, t) if offset > 0 else (t, z)
  3647. t = torch.cat(pair, dim=-1)
  3648. # make sure the diagonal always has the same size
  3649. last_dim += builtins.abs(offset)
  3650. # preserve original data, but place 1 at dim1 and move last dim to dim2
  3651. t = t.unsqueeze(dim1).movedim(-1, dim2)
  3652. # generate ranges shifting indices based on offset
  3653. a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
  3654. b_range = torch.arange(
  3655. offset, last_dim + offset, device=t.device, dtype=torch.int64
  3656. )
  3657. # broadcast
  3658. cond = a_range == b_range.unsqueeze(-1)
  3659. cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
  3660. cond = cond.reshape(cond_shape)
  3661. # aten.diag_embed always returns a new contiguous tensor
  3662. # contiguous() is needed to correctly model the output stride
  3663. return utils.mask_tensor(cond, t).contiguous()
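# Shape sketch for diag_embed (illustrative): the last input dim becomes the
# diagonal of a square block whose side grows with |offset|, e.g.
#
#   torch.diag_embed(torch.tensor([1, 2, 3])).shape            # (3, 3)
#   torch.diag_embed(torch.tensor([1, 2, 3]), offset=1).shape  # (4, 4), values on the first superdiagonal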
  3664. @register_decomposition(aten.block_diag)
  3665. @out_wrapper()
  3666. def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
  3667. """
  3668. Reference implementation of torch.block_diag
  3669. """
  3670. tensors_2d = [
  3671. tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors
  3672. ]
  3673. ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d)
  3674. device = tensors_2d[0].device
  3675. result = []
  3676. col_start = 0
  3677. for i, tensor in enumerate(tensors_2d):
  3678. torch._check(
  3679. tensor.dim() == 2,
  3680. lambda: "Input tensors must have 2 or fewer dimensions. "
  3681. f"Input {i} has {tensor.dim()} dimensions",
  3682. )
  3683. torch._check(
  3684. tensor.device == device,
  3685. lambda: "Input tensors must all be on the same device. "
  3686. f"Input 0 is on device {device} and input {i} is on device {tensor.device}.",
  3687. )
  3688. row, col = tensor.shape
  3689. left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype)
  3690. right = torch.zeros(
  3691. (row, ncols - col_start - col), device=device, dtype=tensor.dtype
  3692. )
  3693. result += [torch.cat((left, tensor, right), dim=1)]
  3694. col_start += col
  3695. return torch.cat(result, dim=0)
  3696. def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
  3697. """
  3698. This is used as an input to PythonRefInfo. `torch.block_diag`
  3699. expects arguments splatted, but `aten.block_diag` expects only
  3700. one argument that is a list of Tensors.
  3701. """
  3702. return _block_diag_iterable(tensors)
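# Usage sketch (informal): blocks are placed along the diagonal, zeros elsewhere.
#
#   A = torch.ones(2, 2)
#   B = torch.full((1, 3), 2.)
#   torch.block_diag(A, B).shape   # (3, 5); A fills rows 0-1 / cols 0-1, B fills row 2 / cols 2-4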
  3703. # CompositeImplicitAutograd - don't register decomp
  3704. def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
  3705. if a.ndim < 3:
  3706. raise RuntimeError(
  3707. f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
  3708. )
  3709. if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
  3710. raise RuntimeError(
  3711. "torch.dsplit attempted to split along dimension 2, "
  3712. + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
  3713. )
  3714. return tensor_split(a, sections, 2)
  3715. @register_decomposition(aten.t.default)
  3716. def t(a: TensorLikeType):
  3717. # TODO: Add sparse support
  3718. # if a.is_sparse:
  3719. # sparse_dim = a.sparse_dim()
  3720. # dense_dim = a.dense_dim()
  3721. # if not (sparse_dim <= 2 and dense_dim == 0):
  3722. # raise RuntimeError(
  3723. # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
  3724. # f"{dense_dim} dense dimensions"
  3725. # )
  3726. if a.ndim > 2:
  3727. raise RuntimeError(
  3728. f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
  3729. )
  3730. return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)
  3731. # CompositeImplicitAutograd - don't register decomp
  3732. def T(a: TensorLikeType) -> TensorLikeType:
# n != 2 and n != 0 is deprecated in regular PyTorch.
  3734. torch._check(
  3735. a.ndim in (0, 2),
  3736. lambda: (
  3737. "The use of `x.T` on tensors of dimension other than 0 or 2 "
  3738. "to reverse their shape is not supported."
  3739. ),
  3740. )
  3741. return a.t()
  3742. @register_decomposition(aten.alias)
  3743. def alias(a: TensorLikeType) -> TensorLikeType:
  3744. return prims.view_of(a)
  3745. @register_decomposition(aten.transpose)
  3746. def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
  3747. _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc]
if a.ndim <= 1 or _dim0 == _dim1:
  3749. return aten.alias.default(a)
  3750. _permutation = list(range(0, a.ndim))
  3751. _permutation[_dim0] = _dim1
  3752. _permutation[_dim1] = _dim0
  3753. return torch.permute(a, _permutation)
  3754. # Aliases for transpose
  3755. swap_axes = transpose
  3756. @register_decomposition(aten.unfold)
  3757. def unfold(
  3758. self: TensorLikeType, dimension: int, size: int, step: int
  3759. ) -> TensorLikeType:
  3760. shape, strides = _get_unfold_shape_stride(
  3761. self.shape, self.stride(), dimension, size, step
  3762. )
  3763. return self.as_strided(shape, strides)
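# Sliding-window sketch (informal, matches the as_strided construction above):
#
#   x = torch.arange(1., 8.)
#   x.unfold(0, 2, 1)   # -> shape (6, 2): [[1, 2], [2, 3], ..., [6, 7]]
#   x.unfold(0, 2, 3)   # -> shape (2, 2): [[1, 2], [4, 5]]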
  3764. @register_decomposition(aten.unfold_copy)
  3765. @out_wrapper()
  3766. def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
  3767. return self.unfold(dimension, size, step).clone(
  3768. memory_format=torch.contiguous_format
  3769. )
  3770. def _cumsumprod_common(
  3771. func,
  3772. init,
  3773. a: TensorLikeType,
  3774. dim: int,
  3775. *,
  3776. dtype: Optional[torch.dtype] = None,
  3777. out: Optional[Tensor] = None,
  3778. ) -> TensorLikeType:
  3779. # We implement all the kwargs of a reduction. ATen just handles dtype
  3780. # nb. This decomposition may not be as efficient as a backend-specific implementation
  3781. ndim = a.ndim
  3782. dim = utils.canonicalize_dim(ndim, dim)
  3783. if ndim == 0:
  3784. return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out)
  3785. a = a.unsqueeze(dim + 1)
  3786. rg = torch.arange(a.shape[dim], device=a.device)
  3787. mask = rg.unsqueeze(1) <= rg
  3788. for _ in range(ndim - dim - 1):
  3789. mask = mask.unsqueeze(-1)
  3790. masked_a = torch.where(mask, a, init)
  3791. return func(masked_a, dim=dim, dtype=dtype, out=out)
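# How the masking trick above computes a scan (worked 1-D sketch):
# for a = [a0, a1, a2] and dim=0, a.unsqueeze(1) broadcasts against the
# (3, 3) mask with mask[i, j] = (i <= j), so masked_a[i, j] = a_i if i <= j else init.
# Reducing over dim 0 then gives out[j] = reduce(a_0, ..., a_j), i.e. the
# cumulative sum (init=0, func=sum) or cumulative product (init=1, func=prod).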
  3792. @register_decomposition(aten.cumsum)
  3793. def cumsum(
  3794. a: TensorLikeType,
  3795. dim: int,
  3796. *,
  3797. dtype: Optional[torch.dtype] = None,
  3798. out: Optional[Tensor] = None,
  3799. ) -> TensorLikeType:
  3800. return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out)
  3801. @register_decomposition(aten.cumprod)
  3802. def cumprod(
  3803. a: TensorLikeType,
  3804. dim: int,
  3805. *,
  3806. dtype: Optional[torch.dtype] = None,
  3807. out: Optional[Tensor] = None,
  3808. ) -> TensorLikeType:
  3809. return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out)
  3810. # Note: although squeeze is documented as having the out= kwarg it doesn't
  3811. @register_decomposition(aten.unsqueeze)
  3812. def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
  3813. # Note that unsqueeze canonicalizes with rank + 1 because it allows
  3814. # a new innermost dimension to be specified
  3815. ndim = a.ndim + 1
  3816. dim = utils.canonicalize_dim(ndim, dim)
  3817. return prims.expand_dims(a, (dim,), ndim=ndim)
# NOTE: shape is a vararg because Tensor.view can be called as
# Tensor.view(a, b, c) or Tensor.view((a, b, c)); the function call torch.view
# doesn't support unpacked shapes
  3821. # TODO: Turn this into a decomposition (currently fails on reshape meta tests)
  3822. @register_decomposition(aten.view.default)
  3823. def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
  3824. return _reshape_view_helper(a, *shape, allow_copy=False)
  3825. # CompositeImplicitAutograd - don't register decomp
  3826. def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
  3827. return self.view(other.size())
  3828. # CompositeImplicitAutograd - don't register decomp
  3829. def ravel(a: TensorLikeType) -> TensorLikeType:
  3830. return reshape(a, (-1,))
  3831. # CompositeImplicitAutograd - don't register decomp
  3832. # missing ref impl. for aten.gather
  3833. @out_wrapper()
  3834. def take_along_dim(
  3835. a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
  3836. ) -> torch.Tensor:
  3837. torch._check(
  3838. a.ndim == indices.ndim,
  3839. lambda: (
  3840. "torch.take_along_dim(): input and indices should have the same "
  3841. f"number of dimensions, but got {a.ndim} dimensions for input, and "
  3842. f"{indices.ndim} dimensions for indices"
  3843. ),
  3844. )
  3845. torch._check(
  3846. utils.is_integer_dtype(indices.dtype),
  3847. lambda: (
  3848. "torch.take_along_dim(): dtype of indices should be int but got "
  3849. f"{indices.dtype} instead"
  3850. ),
  3851. )
  3852. if dim is None:
  3853. return torch.gather(a.view(-1), 0, indices.view(-1))
  3854. else:
  3855. self_sizes = list(a.shape)
  3856. self_sizes[dim] = indices.size(dim)
  3857. broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size())
  3858. indices_broadcast = broadcast_to(indices, broadcast_shape)
  3859. indices_sizes = list(indices.shape)
  3860. indices_sizes[dim] = a.size(dim)
  3861. broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size())
  3862. self_broadcast = broadcast_to(a, broadcast_shape)
  3863. return torch.gather(self_broadcast, dim, indices_broadcast)
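# Usage sketch (informal), e.g. reordering values with argsort indices:
#
#   a = torch.tensor([[10, 30, 20], [60, 40, 50]])
#   idx = torch.argsort(a, dim=1)
#   torch.take_along_dim(a, idx, dim=1)   # -> [[10, 20, 30], [40, 50, 60]]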
  3864. @out_wrapper()
  3865. def empty(
  3866. *shape,
  3867. dtype: Optional[torch.dtype] = None,
  3868. layout: torch.layout = torch.strided,
  3869. device: Optional[DeviceLikeType] = None,
  3870. requires_grad: bool = False,
  3871. pin_memory: bool = False,
  3872. memory_format: torch.memory_format = torch.contiguous_format,
  3873. ) -> TensorLikeType:
  3874. torch._check(
  3875. memory_format != torch.preserve_format,
  3876. lambda: "torch.empty: the Preserve memory format is not supported",
  3877. )
  3878. shape = utils.extract_shape_from_varargs(shape)
  3879. if memory_format == torch.contiguous_format:
  3880. strides = utils.make_contiguous_strides_for(shape)
  3881. elif memory_format == torch.channels_last_3d:
  3882. strides = utils.make_channels_last_3d_strides_for(shape)
  3883. else: # memory_format == torch.channels_last
  3884. torch._check(
  3885. memory_format == torch.channels_last,
  3886. lambda: f"torch.empty: received an unknown memory format {memory_format}!",
  3887. )
  3888. strides = utils.make_channels_last_2d_strides_for(shape)
  3889. return torch.empty_strided(
  3890. shape,
  3891. strides,
  3892. dtype=dtype,
  3893. layout=layout,
  3894. device=device,
  3895. pin_memory=pin_memory,
  3896. requires_grad=requires_grad,
  3897. )
  3898. @out_wrapper()
  3899. def empty_permuted(
  3900. shape,
  3901. physical_layout,
  3902. dtype: Optional[torch.dtype] = None,
  3903. layout: torch.layout = torch.strided,
  3904. device: Optional[DeviceLikeType] = None,
  3905. requires_grad: bool = False,
  3906. pin_memory: bool = False,
  3907. ) -> TensorLikeType:
  3908. return prims.empty_permuted(
  3909. shape,
  3910. physical_layout,
  3911. dtype=dtype,
  3912. device=device,
  3913. requires_grad=requires_grad,
  3914. )
  3915. @register_decomposition(aten.new_empty)
  3916. @out_wrapper()
  3917. def new_empty(
  3918. a: TensorLikeType,
  3919. size: ShapeType,
  3920. *,
  3921. dtype: Optional[torch.dtype] = None,
  3922. layout: Optional[torch.layout] = None,
  3923. device: Optional[DeviceLikeType] = None,
  3924. pin_memory: bool = False,
  3925. ) -> TensorLikeType:
  3926. dtype = a.dtype if dtype is None else dtype
  3927. layout = a.layout if layout is None else layout
  3928. device = a.device if device is None else device
  3929. return torch.empty(
  3930. size,
  3931. dtype=dtype,
  3932. device=device,
  3933. pin_memory=pin_memory,
  3934. layout=layout,
  3935. )
  3936. @register_decomposition(aten.new_empty_strided)
  3937. @out_wrapper()
  3938. def new_empty_strided(
  3939. a: TensorLikeType,
  3940. size: ShapeType,
  3941. stride: StrideType,
  3942. *,
  3943. dtype: Optional[torch.dtype] = None,
  3944. layout: Optional[torch.layout] = None,
  3945. device: Optional[DeviceLikeType] = None,
  3946. pin_memory: bool = False,
  3947. ) -> TensorLikeType:
  3948. """
  3949. Reference implementation of torch.Tensor.new_empty_strided
  3950. """
  3951. dtype = a.dtype if dtype is None else dtype
  3952. layout = a.layout if layout is None else layout
  3953. device = a.device if device is None else device
  3954. return torch.empty_strided(
  3955. size,
  3956. stride,
  3957. dtype=dtype,
  3958. device=device,
  3959. pin_memory=pin_memory,
  3960. layout=layout,
  3961. )
  3962. @register_decomposition(aten.zeros.default)
  3963. @out_wrapper()
  3964. def zeros(
  3965. *size,
  3966. dtype: Optional[torch.dtype] = None,
  3967. layout: torch.layout = torch.strided,
  3968. device: Optional[DeviceLikeType] = None,
  3969. pin_memory: bool = False,
  3970. requires_grad: bool = False,
  3971. ) -> TensorLikeType:
  3972. size = utils.extract_shape_from_varargs(size)
  3973. if dtype is None:
  3974. dtype = torch.get_default_dtype()
  3975. return torch.full(
  3976. size,
  3977. False if dtype == torch.bool else 0,
  3978. dtype=dtype,
  3979. layout=layout,
  3980. device=device,
  3981. pin_memory=pin_memory,
  3982. requires_grad=requires_grad,
  3983. )
  3984. @register_decomposition(aten.new_zeros)
  3985. @out_wrapper()
  3986. def new_zeros(
  3987. a: TensorLikeType,
  3988. size: ShapeType,
  3989. *,
  3990. dtype: Optional[torch.dtype] = None,
  3991. layout: Optional[torch.layout] = None,
  3992. device: Optional[DeviceLikeType] = None,
  3993. pin_memory: bool = False,
  3994. requires_grad: bool = False,
  3995. ) -> TensorLikeType:
  3996. dtype = a.dtype if dtype is None else dtype
  3997. layout = a.layout if layout is None else layout
  3998. device = a.device if device is None else device
  3999. return torch.full(
  4000. size,
  4001. False if (dtype or a.dtype) == torch.bool else 0,
  4002. dtype=dtype,
  4003. layout=layout,
  4004. device=device,
  4005. pin_memory=pin_memory,
  4006. requires_grad=requires_grad,
  4007. )
  4008. @register_decomposition(aten.ones.default)
  4009. @out_wrapper()
  4010. def ones(
  4011. *size,
  4012. dtype: Optional[torch.dtype] = None,
  4013. layout: torch.layout = torch.strided,
  4014. device: Optional[DeviceLikeType] = None,
  4015. pin_memory: bool = False,
  4016. requires_grad: bool = False,
  4017. ) -> TensorLikeType:
  4018. size = utils.extract_shape_from_varargs(size)
  4019. if dtype is None:
  4020. dtype = torch.get_default_dtype()
  4021. return torch.full(
  4022. size,
  4023. True if dtype == torch.bool else 1,
  4024. dtype=dtype,
  4025. layout=layout,
  4026. device=device,
  4027. pin_memory=pin_memory,
  4028. requires_grad=requires_grad,
  4029. )
  4030. @register_decomposition(aten.new_ones)
  4031. @out_wrapper()
  4032. def new_ones(
  4033. a: TensorLikeType,
  4034. size: ShapeType,
  4035. *,
  4036. dtype: Optional[torch.dtype] = None,
  4037. layout: Optional[torch.layout] = None,
  4038. device: Optional[DeviceLikeType] = None,
  4039. pin_memory: bool = False,
  4040. requires_grad: bool = False,
  4041. ) -> TensorLikeType:
  4042. dtype = a.dtype if dtype is None else dtype
  4043. layout = a.layout if layout is None else layout
  4044. device = a.device if device is None else device
  4045. return torch.full(
  4046. size,
  4047. True if (dtype or a.dtype) == torch.bool else 1,
  4048. dtype=dtype,
  4049. layout=layout,
  4050. device=device,
  4051. pin_memory=pin_memory,
  4052. requires_grad=requires_grad,
  4053. )
  4054. @register_decomposition(aten.new_full)
  4055. @out_wrapper()
  4056. def new_full(
  4057. a: TensorLikeType,
  4058. size: ShapeType,
  4059. fill_value: NumberType,
  4060. *,
  4061. dtype: Optional[torch.dtype] = None,
  4062. layout: Optional[torch.layout] = None,
  4063. device: Optional[DeviceLikeType] = None,
  4064. pin_memory: bool = False,
  4065. ) -> TensorLikeType:
  4066. dtype = a.dtype if dtype is None else dtype
  4067. layout = a.layout if layout is None else layout
  4068. device = a.device if device is None else device
  4069. return torch.full(
  4070. size,
  4071. fill_value,
  4072. dtype=dtype,
  4073. layout=layout,
  4074. device=device,
  4075. pin_memory=pin_memory,
  4076. )
  4077. @register_decomposition(aten.empty_like)
  4078. @out_wrapper()
  4079. def empty_like(
  4080. a: TensorLikeType,
  4081. *,
  4082. dtype: Optional[torch.dtype] = None,
  4083. device: Optional[DeviceLikeType] = None,
  4084. layout: Optional[torch.layout] = None,
  4085. pin_memory: bool = False,
  4086. requires_grad: bool = False,
  4087. memory_format: torch.memory_format = torch.preserve_format,
  4088. ) -> TensorLikeType:
  4089. dtype = a.dtype if dtype is None else dtype
  4090. layout = a.layout if layout is None else layout
  4091. device = a.device if device is None else device
  4092. if memory_format != torch.preserve_format:
  4093. return torch.empty(
  4094. a.shape,
  4095. dtype=dtype,
  4096. layout=layout,
  4097. device=device,
  4098. requires_grad=requires_grad,
  4099. pin_memory=pin_memory,
  4100. memory_format=memory_format,
  4101. )
  4102. # memory_format == torch.preserve_format
  4103. logical_to_physical_perm = (
  4104. utils.compute_elementwise_output_logical_to_physical_perm(a)
  4105. )
  4106. # identity perm is [2, 1, 0]
  4107. return torch.empty_permuted(
  4108. a.shape,
  4109. logical_to_physical_perm,
  4110. dtype=dtype,
  4111. layout=layout,
  4112. device=device,
  4113. pin_memory=pin_memory,
  4114. requires_grad=requires_grad,
  4115. )
  4116. @register_decomposition([aten.arange.start_step, aten.arange.start_out])
  4117. @out_wrapper()
  4118. def arange(
  4119. start: NumberType = 0,
  4120. end: Optional[NumberType] = None,
  4121. step: NumberType = 1,
  4122. *,
  4123. dtype: Optional[torch.dtype] = None,
  4124. layout: torch.layout = torch.strided,
  4125. device: Optional[DeviceLikeType] = None,
  4126. pin_memory: bool = False,
  4127. requires_grad: bool = False,
  4128. ) -> TensorLikeType:
  4129. utils.check_layout(layout)
  4130. utils.check_pin_memory(pin_memory)
  4131. device = torch.device(utils.device_or_default(device))
  4132. assert not isinstance(start, complex)
  4133. assert not isinstance(end, complex)
  4134. assert not isinstance(step, complex)
  4135. # Case: torch.arange(5)
  4136. if end is None:
  4137. end = start
  4138. start = 0
  4139. torch._check(step != 0, lambda: "step must be nonzero")
  4140. if step > 0:
  4141. torch._check(
  4142. end >= start,
  4143. lambda: "upper bound and lower bound inconsistent with step sign",
  4144. )
  4145. elif step < 0:
  4146. torch._check(
  4147. end <= start,
  4148. lambda: "upper bound and lower bound inconsistent with step sign",
  4149. )
  4150. def is_finite(x):
  4151. return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
  4152. torch._check(
  4153. is_finite(start) and is_finite(end),
  4154. lambda: f"unsupported range: {start} -> {end}",
  4155. )
  4156. torch._check(
  4157. is_finite(step),
  4158. lambda: f"step must be finite but got {step}",
  4159. )
  4160. args = (start, end, step)
  4161. integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
  4162. if dtype is None:
  4163. dtype = torch.int64 if integer_args else torch.get_default_dtype()
  4164. is_integer = utils.is_integer_dtype(dtype)
  4165. if is_integer:
  4166. xstart = sym_int(start)
  4167. xend = sym_int(end)
  4168. xstep = sym_int(step)
# For int64 we truncate the arguments to int before calculating the length, but
# for other integral dtypes we don't. Weird, but needed to match ATen shapes.
  4171. if dtype == torch.int64:
  4172. # Uses floordiv to avoid ceil in inductor.
  4173. sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
  4174. length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]
  4175. else:
  4176. length = math.ceil((end - start) / step)
  4177. if is_integer:
  4178. return prims.iota(
  4179. length,
  4180. start=xstart, # type: ignore[possibly-undefined]
  4181. step=xstep, # type: ignore[possibly-undefined]
  4182. dtype=dtype,
  4183. device=device,
  4184. requires_grad=requires_grad,
  4185. )
  4186. index = prims.iota(
  4187. length,
  4188. start=0,
  4189. step=1,
  4190. dtype=torch.int64,
  4191. device=device,
  4192. requires_grad=False,
  4193. )
  4194. computation_dtype = (
  4195. torch.long if integer_args else utils.get_acc_type(dtype, device)
  4196. )
  4197. index = _maybe_convert_to_dtype(index, computation_dtype)
  4198. result = start + step * index
  4199. result = _maybe_convert_to_dtype(result, dtype)
  4200. if requires_grad:
  4201. result.requires_grad_(True)
  4202. return result
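# Length-formula sketch for the int64 path above (informal check):
#   length = (end - start + step - sgn) // step
#   torch.arange(0, 10, 3)  -> (10 - 0 + 3 - 1) // 3  = 4 elements: 0, 3, 6, 9
#   torch.arange(5, 0, -2)  -> (0 - 5 - 2 + 1) // -2  = 3 elements: 5, 3, 1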
  4203. @register_decomposition(aten.lerp)
  4204. @out_wrapper()
  4205. @elementwise_type_promotion_wrapper(
  4206. type_promoting_args=("start", "end", "weight"),
  4207. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  4208. )
  4209. def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
  4210. inputs = [start, end]
  4211. if isinstance(weight, Number):
  4212. weight = start.new_full((), weight) # type: ignore[arg-type]
  4213. else:
  4214. inputs.append(weight)
  4215. assert isinstance(weight, Tensor) # mypy
  4216. # We implement it this way for numerical stability. We assume (in the stability optimisation)
  4217. # that 0 <= weight <= 1. We take the abs to deal with complex numbers
  4218. # We want to perform operations near zero, which is where floating points are most precise
  4219. # thus, we perform the following optimisation:
  4220. # If weight.abs() >= 0.5:
  4221. # return (1 - weight) * (start - end) + end
  4222. mask = weight.abs() >= 0.5
  4223. coeff = torch.where(mask, weight - 1, weight)
  4224. base = torch.where(mask, end, start)
  4225. output = coeff * (end - start) + base
  4226. # make sure the decomposition output's stride is same as non-decomposition path.
  4227. stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
  4228. if output.stride() != stride:
  4229. output = prims.copy_strided(output, stride)
  4230. return handle_noncontiguous_outputs(inputs, output)
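# Why the branch above helps (informal sketch): both branches are algebraically
#   start + weight * (end - start),
# but under the 0 <= weight <= 1 assumption noted above, picking coeff/base by
# |weight| >= 0.5 keeps |coeff| <= 0.5, so the correction term stays small
# relative to the nearer endpoint; e.g. for weight == 1 the result is exactly
# `end` rather than `start + 1 * (end - start)`.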
  4231. @register_decomposition(aten.linspace)
  4232. @out_wrapper()
  4233. def linspace(
  4234. start: Union[NumberType, TensorLikeType],
  4235. end: Union[NumberType, TensorLikeType],
  4236. steps: NumberType,
  4237. *,
  4238. dtype: Optional[torch.dtype] = None,
  4239. device: Optional[DeviceLikeType] = None,
  4240. layout: torch.layout = torch.strided,
  4241. pin_memory: bool = False,
  4242. requires_grad: bool = False,
  4243. ) -> TensorLikeType:
  4244. if isinstance(start, TensorLikeType):
  4245. torch._check(
  4246. start.dim() == 0,
  4247. lambda: "linspace only supports 0-dimensional start and end tensors",
  4248. )
  4249. start = _maybe_convert_to_dtype(start, torch.float64)
  4250. if isinstance(end, TensorLikeType):
  4251. torch._check(
  4252. end.dim() == 0,
  4253. lambda: "linspace only supports 0-dimensional start and end tensors",
  4254. )
  4255. end = _maybe_convert_to_dtype(end, torch.float64)
  4256. if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
  4257. default_complex_dtype = utils.corresponding_complex_dtype(
  4258. torch.get_default_dtype()
  4259. )
  4260. if dtype is None:
  4261. dtype = default_complex_dtype
  4262. else:
  4263. torch._check(
  4264. utils.is_complex_dtype(dtype),
  4265. lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
  4266. )
  4267. else:
  4268. dtype = dtype or torch.get_default_dtype()
  4269. assert isinstance(dtype, torch.dtype)
  4270. # steps does not participate in the computation of the dtype
  4271. torch._check_type(
  4272. isinstance(steps, IntLike),
  4273. lambda: f"received an invalid combination of arguments - got \
  4274. ({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
  4275. )
  4276. assert isinstance(steps, IntLike) # for mypy
  4277. torch._check(steps >= 0, lambda: "number of steps must be non-negative")
  4278. factory_kwargs = {
  4279. "layout": layout,
  4280. "device": device,
  4281. "pin_memory": pin_memory,
  4282. "requires_grad": requires_grad,
  4283. }
  4284. if steps == 0:
  4285. return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
  4286. if steps == 1:
  4287. if isinstance(start, TensorLikeType):
  4288. return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start) # type: ignore[arg-type]
  4289. else:
  4290. return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
# Perform the arange in int because some backends like ATen or Triton do not support all the dtypes
  4292. rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type]
  4293. # Small types need to be computed in higher precision as this is, at heart, an associative scan
  4294. dtype_red = (
  4295. torch.int64
  4296. if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
  4297. else dtype
  4298. )
  4299. computation_dtype, _ = utils.reduction_dtypes(
  4300. rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
  4301. )
  4302. cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)
  4303. # We implement torch.lerp without performing rg / (steps - 1) explicitly
  4304. # With this we get out[0] == start, out[-1] == end
  4305. step = (end - start) / (steps - 1)
  4306. out = torch.where(
  4307. rg < steps / 2,
  4308. start + step * cast_rg(rg), # type: ignore[arg-type,operator]
  4309. end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator]
  4310. )
  4311. return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value]
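# Endpoint sketch (informal): the two-sided formula above counts up from `start`
# for the first half and down from `end` for the second half, so both endpoints
# are hit exactly; e.g. linspace(0, 1, 5) yields
#   i = 0..2: start + step * i        -> 0.00, 0.25, 0.50
#   i = 3..4: end - step * (4 - i)    -> 0.75, 1.00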
  4312. @register_decomposition(aten.logspace)
  4313. @out_wrapper()
  4314. def logspace(
  4315. start: Union[NumberType, TensorLikeType],
  4316. end: Union[NumberType, TensorLikeType],
  4317. steps: NumberType,
  4318. base: NumberType = 10,
  4319. *,
  4320. dtype: Optional[torch.dtype] = None,
  4321. device: Optional[DeviceLikeType] = None,
  4322. layout: torch.layout = torch.strided,
  4323. pin_memory: bool = False,
  4324. requires_grad: bool = False,
  4325. ) -> TensorLikeType:
  4326. if dtype is None:
  4327. dtype = torch.get_default_dtype()
  4328. # NB: NumPy doesn't have this cast
  4329. if prims.utils.is_integer_dtype(dtype):
  4330. if isinstance(start, FloatLike):
  4331. start = sym_int(start)
  4332. elif isinstance(start, TensorLikeType):
  4333. torch._check(
  4334. start.dim() == 0,
  4335. lambda: "logspace only supports 0-dimensional start and end tensors",
  4336. )
  4337. start = _maybe_convert_to_dtype(start, dtype)
  4338. if isinstance(end, FloatLike):
  4339. end = sym_int(end)
  4340. elif isinstance(end, TensorLikeType):
  4341. torch._check(
  4342. end.dim() == 0,
  4343. lambda: "logspace only supports 0-dimensional start and end tensors",
  4344. )
  4345. end = _maybe_convert_to_dtype(end, dtype)
  4346. if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
  4347. default_complex_dtype = utils.corresponding_complex_dtype(
  4348. torch.get_default_dtype()
  4349. )
  4350. dtype = default_complex_dtype
  4351. _dtype = None # torch.linspace will update the correct dtype
  4352. else:
  4353. _dtype = torch.float64
  4354. assert not isinstance(base, complex) # for mypy
  4355. if base < 0:
  4356. raise NotImplementedError
  4357. ret = torch.linspace( # type: ignore[misc]
  4358. start, # type: ignore[arg-type]
  4359. end, # type: ignore[arg-type]
  4360. steps, # type: ignore[arg-type]
  4361. dtype=_dtype,
  4362. layout=layout,
  4363. device=device,
  4364. pin_memory=pin_memory,
  4365. requires_grad=requires_grad,
  4366. )
  4367. return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value]
  4368. @overload
  4369. def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
  4370. pass
  4371. @overload
  4372. def meshgrid(*tensors: TensorLikeType, indexing: str):
  4373. pass
  4374. @register_decomposition(aten.meshgrid)
  4375. def meshgrid(
  4376. *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
  4377. indexing: str,
  4378. ) -> List[TensorLikeType]:
  4379. # This ref simultaneously handles two overloads (see stubs above)
  4380. # The `indexing` argument is currently optional for torch.meshgrid, but we
  4381. # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
  4382. if isinstance(tensors[0], (list, tuple)):
  4383. assert len(tensors) == 1
  4384. tensors = tuple(tensors[0])
  4385. torch._check(
  4386. py_all(isinstance(a, TensorLike) for a in tensors),
  4387. lambda: "meshgrid expects its inputs to be tensors",
  4388. )
  4389. torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
  4390. for i in range(len(tensors) - 1):
  4391. torch._check(
  4392. tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
  4393. lambda: "meshgrid expects all tensors to have the same dtype",
  4394. )
  4395. torch._check(
  4396. tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
  4397. lambda: "meshgrid expects all tensors to have the same device",
  4398. )
  4399. swap_first_and_second_tensors = False
  4400. if indexing == "xy":
  4401. swap_first_and_second_tensors = len(tensors) >= 2
  4402. if swap_first_and_second_tensors:
  4403. tensors = (tensors[1], tensors[0], *tensors[2:])
  4404. else:
  4405. torch._check(
  4406. indexing == "ij",
  4407. lambda: (
  4408. 'torch.meshgrid: indexing must be one of "xy" or "ij", '
  4409. f"but received: {indexing}"
  4410. ),
  4411. )
  4412. result_shape: List[int] = []
  4413. for t in tensors:
  4414. assert isinstance(t, TensorLike) # mypy
  4415. torch._check(
  4416. t.ndim == 0 or t.ndim == 1,
  4417. lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
  4418. )
  4419. result_shape.append(t.numel())
  4420. grids: List[TensorLikeType] = []
  4421. for i, t in enumerate(tensors):
  4422. assert isinstance(t, TensorLike) # mypy
  4423. if t.ndim == 0:
  4424. t = t.view((1,))
  4425. grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))
  4426. if swap_first_and_second_tensors:
  4427. # Swap outputs if we originally swapped at the beginning
  4428. grids[0], grids[1] = grids[1], grids[0]
  4429. return grids
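# Broadcasting sketch (informal): each 1-D input is expanded along its own axis.
#
#   x, y = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
#   gx, gy = torch.meshgrid(x, y, indexing="ij")   # both (3, 2); gx[i, j] == x[i], gy[i, j] == y[j]
#   gx, gy = torch.meshgrid(x, y, indexing="xy")   # both (2, 3); the first two grids come out transposed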
  4430. # CompositeImplicitAutograd - don't register decomp
  4431. def movedim(
  4432. input: TensorLikeType,
  4433. source: Union[int, DimsSequenceType],
  4434. destination: Union[int, DimsSequenceType],
  4435. ) -> TensorLikeType:
  4436. """
  4437. Reference implementation of torch.movedim
  4438. """
  4439. if type(source) is int:
  4440. source = (source,)
  4441. if type(destination) is int:
  4442. destination = (destination,)
  4443. # Converts to list to produce a compatible error message with core PyTorch,
  4444. # which prints sequences in square brackets.
  4445. torch._check(
  4446. len(source) == len(destination), # type: ignore[arg-type]
  4447. lambda: (
  4448. "movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
  4449. f"({list(source)} dims) should contain the same number " # type: ignore[arg-type]
  4450. f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type]
  4451. ),
  4452. )
  4453. rank = input.ndim
  4454. ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type]
  4455. ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type]
  4456. sss = set(ss)
  4457. dss = set(ds)
  4458. # See above on why this converts to list in error messages.
  4459. torch._check(
  4460. len(ss) == len(sss),
  4461. lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
  4462. )
  4463. torch._check(
  4464. len(ds) == len(dss),
  4465. lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
  4466. )
  4467. m = dict(zip(ds, ss))
  4468. dims = []
  4469. si = 0 # source index
  4470. for di in range(rank):
  4471. # check if the destination index is in the mapping
  4472. s = m.get(di)
  4473. if s is not None:
  4474. # insert source index if found
  4475. dims.append(s)
  4476. else:
  4477. # insert source index sequentially, skipping indices from the mapping
  4478. while si in sss:
  4479. si += 1
  4480. dims.append(si)
  4481. si += 1
  4482. result = torch.permute(input, tuple(dims))
  4483. return result
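# Permutation sketch for the loop above (informal): movedim(x, 0, 2) on a
# (2, 3, 4) tensor builds m = {2: 0}; filling destinations 0, 1, 2 yields
# dims = (1, 2, 0), so the result is x.permute(1, 2, 0) with shape (3, 4, 2).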
  4484. # NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
  4485. @register_decomposition(aten.empty_strided)
  4486. @out_wrapper()
  4487. def empty_strided(
  4488. shape: Union[ShapeType, Tuple[ShapeType]],
  4489. strides: StrideType,
  4490. *,
  4491. dtype: Optional[torch.dtype] = None,
  4492. device: Optional[DeviceLikeType] = None,
  4493. layout: torch.layout = torch.strided,
  4494. requires_grad: bool = False,
  4495. pin_memory: bool = False,
  4496. ) -> TensorLikeType:
  4497. # Layout == strided, pin_memory is False
  4498. utils.check_layout(layout)
  4499. utils.check_pin_memory(pin_memory)
  4500. shape = utils.extract_shape_from_varargs(shape)
  4501. dtype = torch.get_default_dtype() if dtype is None else dtype
  4502. device = torch.device("cpu") if device is None else device
  4503. return prims.empty_strided(
  4504. shape,
  4505. strides,
  4506. dtype=dtype,
  4507. device=device,
  4508. requires_grad=requires_grad,
  4509. )
  4510. @register_decomposition(aten.eye)
  4511. @out_wrapper()
  4512. def eye(
  4513. n: int,
  4514. m: Optional[int] = None,
  4515. *,
  4516. dtype: Optional[torch.dtype] = None,
  4517. layout: torch.layout = torch.strided,
  4518. device: Optional[DeviceLikeType] = None,
  4519. pin_memory: bool = False,
  4520. requires_grad: bool = False, # TODO: unused
  4521. ) -> TensorLikeType:
  4522. """
  4523. Reference implementation of torch.eye
  4524. """
  4525. if m is None:
  4526. m = n
  4527. torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
  4528. torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
  4529. range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
  4530. range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
  4531. cond = range_n.unsqueeze(-1) == range_m
  4532. if dtype is torch.bool:
  4533. return cond
  4534. else:
  4535. one = torch.ones(
  4536. (1,),
  4537. dtype=dtype,
  4538. layout=layout,
  4539. device=device,
  4540. pin_memory=pin_memory,
  4541. requires_grad=False,
  4542. )
  4543. return torch.where(cond, one, 0)
  4544. # TODO: Use requires_grad. All refs taking the requires_grad kwarg must
  4545. # return a leaf tensor.
  4546. # result.requires_grad_(requires_grad)
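# Construction sketch (informal): the outer comparison
#   torch.arange(n).unsqueeze(-1) == torch.arange(m)
# yields an (n, m) boolean mask that is True exactly on the main diagonal,
# which where(cond, one, 0) then lifts to the requested dtype.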
  4547. @register_decomposition([aten.full.default, aten.full.out])
  4548. @out_wrapper()
  4549. def full(
  4550. shape: ShapeType,
  4551. fill_value: NumberType,
  4552. *,
  4553. dtype: Optional[torch.dtype] = None,
  4554. layout: torch.layout = torch.strided,
  4555. device: Optional[DeviceLikeType] = None,
  4556. pin_memory: bool = False,
  4557. requires_grad: bool = False,
  4558. ) -> TensorLikeType:
  4559. utils.check_layout(layout)
  4560. utils.check_pin_memory(pin_memory)
  4561. dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
  4562. device = device if device is not None else torch.device("cpu")
  4563. e = empty(
  4564. shape,
  4565. dtype=dtype,
  4566. layout=layout,
  4567. device=device,
  4568. pin_memory=pin_memory,
  4569. requires_grad=requires_grad,
  4570. )
  4571. return torch.fill(e, fill_value) # type: ignore[arg-type]
  4572. def full_like(
  4573. a: TensorLikeType,
  4574. fill_value: NumberType,
  4575. *,
  4576. dtype: Optional[torch.dtype] = None,
  4577. layout: Optional[torch.layout] = None,
  4578. device: Optional[DeviceLikeType] = None,
  4579. pin_memory: bool = False,
  4580. requires_grad: bool = False,
  4581. memory_format: torch.memory_format = torch.preserve_format,
  4582. ) -> TensorLikeType:
  4583. e = torch.empty_like(
  4584. a,
  4585. dtype=dtype,
  4586. layout=layout,
  4587. device=device,
  4588. pin_memory=pin_memory,
  4589. requires_grad=requires_grad,
  4590. memory_format=memory_format,
  4591. )
  4592. return fill(e, fill_value)
  4593. @register_decomposition(aten.zeros_like)
  4594. @out_wrapper()
  4595. def zeros_like(
  4596. a: TensorLikeType,
  4597. *,
  4598. dtype: Optional[torch.dtype] = None,
  4599. layout: Optional[torch.layout] = None,
  4600. device: Optional[DeviceLikeType] = None,
  4601. pin_memory: bool = False,
  4602. requires_grad: bool = False,
  4603. memory_format: torch.memory_format = torch.preserve_format,
  4604. ) -> TensorLikeType:
  4605. return torch.full_like(
  4606. a,
  4607. False if (dtype or a.dtype) == torch.bool else 0,
  4608. dtype=dtype,
  4609. layout=layout,
  4610. device=device,
  4611. pin_memory=pin_memory,
  4612. requires_grad=requires_grad,
  4613. memory_format=memory_format,
  4614. )
  4615. @register_decomposition(aten.ones_like)
  4616. @out_wrapper()
  4617. def ones_like(
  4618. a: TensorLikeType,
  4619. *,
  4620. dtype: Optional[torch.dtype] = None,
  4621. layout: Optional[torch.layout] = None,
  4622. device: Optional[DeviceLikeType] = None,
  4623. pin_memory: bool = False,
  4624. requires_grad: bool = False,
  4625. memory_format: torch.memory_format = torch.preserve_format,
  4626. ) -> TensorLikeType:
  4627. return torch.full_like(
  4628. a,
  4629. True if (dtype or a.dtype) == torch.bool else 1,
  4630. dtype=dtype,
  4631. layout=layout,
  4632. device=device,
  4633. pin_memory=pin_memory,
  4634. requires_grad=requires_grad,
  4635. memory_format=memory_format,
  4636. )
  4637. @register_decomposition(aten.randn.default)
  4638. @out_wrapper()
  4639. def randn(
  4640. *shape,
  4641. dtype: Optional[torch.dtype] = None,
  4642. device: Optional[DeviceLikeType] = None,
  4643. layout: Optional[torch.layout] = None,
  4644. requires_grad: bool = False,
  4645. pin_memory: bool = False,
  4646. ) -> TensorLikeType:
  4647. utils.check_pin_memory(pin_memory)
  4648. shape_ = utils.extract_shape_from_varargs(shape)
  4649. dtype = utils.dtype_or_default(dtype)
  4650. device = utils.device_or_default(device)
  4651. return prims.normal(
  4652. shape_,
  4653. mean=0.0,
  4654. std=1.0,
  4655. dtype=dtype,
  4656. device=device,
  4657. requires_grad=requires_grad,
  4658. )
  4659. def scalar_tensor(
  4660. a: NumberType,
  4661. *,
  4662. dtype: Optional[torch.dtype] = None,
  4663. layout: torch.layout = torch.strided,
  4664. device: Optional[DeviceLikeType] = None,
  4665. pin_memory: bool = False,
  4666. ) -> TensorLikeType:
  4667. utils.check_layout(layout)
  4668. utils.check_pin_memory(pin_memory)
  4669. dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
  4670. device = device if device is not None else torch.device("cpu")
  4671. return prims.scalar_tensor(a, dtype=dtype, device=device)
  4672. #
  4673. # Randomness References
  4674. #
  4675. def _uniform_helper(
  4676. shape: ShapeType,
  4677. low: Union[bool, int, float] = 0.0,
  4678. high: Union[bool, int, float] = 1.0,
  4679. *,
  4680. dtype: torch.dtype,
  4681. device: DeviceLikeType,
  4682. ) -> TensorLikeType:
  4683. utils.validate_shape(shape)
  4684. assert isinstance(low, Number)
  4685. assert isinstance(high, Number)
  4686. low = sym_float(low)
  4687. high = sym_float(high)
  4688. assert isinstance(dtype, torch.dtype)
  4689. device = utils.canonicalize_device(device)
  4690. return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)
  4691. @register_decomposition(aten.masked_fill)
  4692. @out_wrapper()
  4693. def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
  4694. python_type = utils.dtype_to_type(a.dtype)
  4695. if isinstance(value, Number):
  4696. value_type = type(value)
  4697. else:
  4698. # NOTE: Could not use value = item(value) as it resulted in
  4699. # RuntimeError: Cannot cast FakeTensor(cpu) to number
  4700. value_ndim = value.ndim
  4701. torch._check(
  4702. value_ndim == 0,
  4703. lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
  4704. )
  4705. # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
  4706. is_cpu_scalar = (
  4707. a.device.type in ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
  4708. and value.device.type == "cpu"
  4709. )
  4710. torch._check(
  4711. is_cpu_scalar or value.device == a.device,
  4712. lambda: "Expected `value` to be on same device as `a`",
  4713. )
  4714. value_type = utils.dtype_to_type(value.dtype)
  4715. if value_type is complex:
# Only downcasting from complex to a lower type is not allowed;
# we allow casting `value` to a lower type in other cases,
# e.g. float -> int.
  4719. # Ref: https://github.com/pytorch/pytorch/issues/79195
  4720. torch._check(
  4721. utils.is_weakly_lesser_type(value_type, python_type),
  4722. lambda: f"could not convert to type {python_type} without overflow",
  4723. )
  4724. # Since `where` allows type-promotion,
  4725. # cast value to correct type before passing to `where`
  4726. value = _maybe_convert_to_dtype(value, a.dtype)
  4727. r = torch.where(mask, value, a) # type: ignore[arg-type]
# aten.masked_fill always returns a new contiguous tensor
  4729. # contiguous() is needed to correctly model the output stride
  4730. return r.contiguous()
  4731. @register_decomposition(aten.masked_fill_)
  4732. def masked_fill_(
  4733. a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
  4734. ) -> TensorLikeType:
  4735. b = torch.masked_fill(a, mask, value) # type: ignore[arg-type]
  4736. a.copy_(b)
  4737. return a
  4738. # CompositeImplicitAutograd - don't register decomp
  4739. def allclose(
  4740. a: TensorLikeType,
  4741. b: TensorLikeType,
  4742. rtol: float = 1e-05,
  4743. atol: float = 1e-08,
  4744. equal_nan: bool = False,
  4745. ) -> bool:
  4746. """
  4747. Reference implementation of torch.allclose
  4748. """
  4749. _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
  4750. return bool(
  4751. torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
  4752. )
  4753. def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
  4754. utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
  4755. utils.check_same_dtype(a, b)
  4756. # Shape check
  4757. if a.ndim != b.ndim:
  4758. return False
  4759. for x, y in zip(a.shape, b.shape):
  4760. if x != y:
  4761. return False
  4762. # Short-circuits if there are no elements to validate
  4763. if a.numel() == 0:
  4764. return True
  4765. return item(all(eq(a, b))) # type: ignore[return-value]
  4766. @register_decomposition(aten.norm)
  4767. @out_wrapper(exact_dtype=True)
  4768. def norm(
  4769. input: TensorLikeType,
  4770. p: Optional[Union[float, str]] = "fro",
  4771. dim: Optional[DimsType] = None,
  4772. keepdim: bool = False,
  4773. *,
  4774. dtype: Optional[torch.dtype] = None,
  4775. ) -> TensorLikeType:
  4776. # In these cases we compute the "Frobenius norm"
  4777. if (
  4778. p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
  4779. ) or p is None:
  4780. p = 2
  4781. if isinstance(dim, Dim):
  4782. dim = [dim]
  4783. if isinstance(p, str):
  4784. # Here we either call the nuclear norm, or we call matrix_norm with some arguments
  4785. # that will throw an error
  4786. if dim is None:
  4787. dim = tuple(range(input.ndim))
  4788. return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
  4789. else:
  4790. return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)
  4791. @register_decomposition(aten.trace)
  4792. @out_wrapper()
  4793. def trace(self: TensorLikeType) -> TensorLikeType:
  4794. torch._check(
self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
  4796. )
  4797. return torch.sum(torch.diag(self, 0))
  4798. def _make_r_binary_op(base_op):
  4799. def rop(
  4800. a: Union[TensorLikeType, NumberType],
  4801. b: Union[TensorLikeType, NumberType],
  4802. ) -> TensorLikeType:
  4803. return base_op(b, a)
  4804. return rop
  4805. rtruediv = _make_r_binary_op(true_divide)
  4806. rfloordiv = _make_r_binary_op(floor_divide)
  4807. rpow = _make_r_binary_op(pow)
  4808. @register_decomposition(aten.triu)
  4809. @out_wrapper()
  4810. def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
  4811. torch._check(
  4812. a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
  4813. )
  4814. h, w = a.shape[-2:]
  4815. mask = (
  4816. torch.arange(w, device=a.device).unsqueeze(-2)
  4817. - torch.arange(h, device=a.device).unsqueeze(-1)
  4818. ) >= diagonal
  4819. # aten.triu always returns a new contiguous tensor
  4820. # contiguous() is needed to correctly model the output stride
  4821. return utils.mask_tensor(mask, a).contiguous()
  4822. @register_decomposition(aten.tril)
  4823. @out_wrapper()
  4824. def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
  4825. torch._check(
  4826. a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
  4827. )
  4828. h, w = a.shape[-2:]
  4829. mask = (
  4830. torch.arange(w, device=a.device).unsqueeze(-2)
  4831. - torch.arange(h, device=a.device).unsqueeze(-1)
  4832. ) <= diagonal
  4833. # aten.tril always returns a new contiguous tensor
  4834. # contiguous() is needed to correctly model the output stride
  4835. return utils.mask_tensor(mask, a).contiguous()
  4836. # This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
  4837. # The components of the matrix that belong to the lower triangle with offset
  4838. # form a pentagon that can be broken down into a top trapezoid and a bottom
  4839. # rectangle. For the implementation of tril_indices, we need the sizes of
  4840. # both of these, as well as the length of the top side of the trapezoid.
  4841. def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
  4842. if row == 0 or col == 0:
  4843. return 0, 0, 0
  4844. m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
  4845. m_last_row = max(0, min(col, row + offset))
  4846. n_row_all = max(0, min(row, row + offset))
  4847. n_row_trapezoid = m_last_row - m_first_row + 1
  4848. # Number of elements in top trapezoid
  4849. trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
  4850. # Number of elements in bottom rectangle
  4851. diff_row = n_row_all - n_row_trapezoid
  4852. rectangle_size = max(0, diff_row * col)
  4853. return trapezoid_size, rectangle_size, m_first_row
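# Worked sketch (informal): row=5, col=3, offset=0 gives
#   m_first_row = 1, m_last_row = 3, n_row_all = 5, n_row_trapezoid = 3
#   trapezoid_size = (1 + 3) * 3 // 2 = 6, rectangle_size = 2 * 3 = 6
# which matches the 12 lower-triangular elements of a 5x3 matrix
# (rows of length 1, 2, 3, 3, 3).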
  4854. def _trilu_checks(
  4855. name: str,
  4856. row: int,
  4857. col: int,
  4858. dtype: torch.dtype,
  4859. layout: torch.layout,
  4860. pin_memory: bool,
  4861. ):
  4862. torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
  4863. torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
  4864. torch._check(
  4865. dtype in (torch.int32, torch.int64),
  4866. lambda: f"\"{name}\" not implemented for '{dtype}'",
  4867. )
  4868. # This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
  4869. @register_decomposition(aten.tril_indices)
  4870. @out_wrapper()
  4871. def tril_indices(
  4872. row: int,
  4873. col: int,
  4874. offset: int = 0,
  4875. *,
  4876. dtype: torch.dtype = torch.long,
  4877. layout: torch.layout = torch.strided,
  4878. device: DeviceLikeType = "cpu",
  4879. pin_memory: bool = False,
  4880. ) -> TensorLikeType:
  4881. _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)
  4882. trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
  4883. row_offset = max(0, -offset)
  4884. arange_kw = partial(
  4885. torch.arange, layout=layout, device=device, pin_memory=pin_memory
  4886. )
  4887. # first we do the indices for top trapezoid
  4888. xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
  4889. b = m_first_row - 0.5
  4890. row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
  4891. col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
  4892. row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
  4893. col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
  4894. # then bottom rectangle
  4895. xs2 = arange_kw(0, rectangle_size, dtype=dtype)
  4896. row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
  4897. col_inds2 = xs2 % col
  4898. return torch.stack(
  4899. (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
  4900. )
  4901. # Similar to _get_tril_sizes above, but here there is a top trapezoid and
  4902. # a bottom rectangle instead. Note that you can't reduce this to
  4903. # _get_tril_sizes(col, row, -offset) because that would correspond to
  4904. # decomposing into a left trapezoid and right rectangle.
  4905. def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
  4906. if row == 0 or col == 0:
  4907. return 0, 0, 0
  4908. m_first_row = max(0, col - offset) if offset > 0 else col
  4909. # Number of elements in top rectangle
  4910. rectangle_size = max(0, min(row, -offset) * col)
  4911. # Number of elements in bottom trapezoid
  4912. trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
  4913. triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
  4914. trapezoid_size = triu_size - rectangle_size
  4915. return trapezoid_size, rectangle_size, m_first_row
  4916. @register_decomposition(aten.triu_indices)
  4917. @out_wrapper()
  4918. def triu_indices(
  4919. row: int,
  4920. col: int,
  4921. offset: int = 0,
  4922. *,
  4923. dtype: torch.dtype = torch.long,
  4924. layout: torch.layout = torch.strided,
  4925. device: DeviceLikeType = "cpu",
  4926. pin_memory: bool = False,
  4927. ) -> TensorLikeType:
  4928. _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)
  4929. trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
  4930. col_offset = max(0, offset)
  4931. arange_kw = partial(
  4932. torch.arange, layout=layout, device=device, pin_memory=pin_memory
  4933. )
  4934. # indices for top rectangle
  4935. xs2 = arange_kw(0, rectangle_size, dtype=dtype)
  4936. row_inds2 = xs2 // col
  4937. col_inds2 = xs2 % col
  4938. # bottom trapezoid
  4939. xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
  4940. b = -0.5 - m_first_row
  4941. row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
  4942. col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
  4943. row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
  4944. col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
  4945. if col:
  4946. row_inds1 = row_inds1 + (rectangle_size // col)
  4947. col_inds1 = col_inds1 + col_offset
  4948. return torch.stack(
  4949. (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
  4950. )
  4951. @register_decomposition(aten.bucketize)
  4952. @out_wrapper(exact_dtype=True)
  4953. def bucketize(
  4954. a: TensorLikeType,
  4955. boundaries: TensorLikeType,
  4956. *,
  4957. out_int32: bool = False,
  4958. right: bool = False,
  4959. ):
  4960. torch._check(
  4961. boundaries.dim() == 1,
  4962. lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
  4963. )
  4964. out_dtype = torch.int32 if out_int32 else torch.int64
  4965. n_boundaries = boundaries.shape[-1]
  4966. if n_boundaries == 0:
  4967. return torch.zeros_like(a)
  4968. # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
# each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
  4970. # but each step of the search is done "in parallel" over all elements of `a`
  4971. # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end
  4972. start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
  4973. end = start + n_boundaries
  4974. # Max depth of the binary search
  4975. # Since we can't break out of the loop at different points for different elements of a,
# we just do the max number of iterations that binary search requires and add a condition
# tensor (cond_update below) to stop updating once the search terminates.
# The first iteration through the loop can skip some checks, so it is handled separately below.
  4979. mid = start + (end - start) // 2
  4980. mid_val = boundaries[mid]
  4981. if right:
  4982. cond_mid = mid_val > a
  4983. else:
  4984. cond_mid = mid_val >= a
  4985. start = torch.where(cond_mid, start, mid + 1)
  4986. if n_boundaries > 1:
  4987. cond_update = torch.ones_like(a, dtype=torch.bool)
  4988. niters = int(math.log2(n_boundaries))
  4989. for _ in range(niters):
  4990. end = torch.where(cond_mid & cond_update, mid, end)
  4991. cond_update = start < end
  4992. # start might end up pointing to 1 past the end, we guard against that
  4993. mid = torch.where(cond_update, start + (end - start) // 2, 0)
  4994. mid_val = boundaries[mid]
  4995. # If right is true, the buckets are closed on the *left*
  4996. # (i.e., we are doing the equivalent of std::upper_bound in C++)
  4997. # Otherwise they are closed on the right (std::lower_bound)
  4998. if right:
  4999. cond_mid = mid_val > a
  5000. else:
  5001. cond_mid = mid_val >= a
  5002. start = torch.where((~cond_mid) & cond_update, mid + 1, start)
  5003. return start.to(dtype=out_dtype)
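# Search sketch (informal): with b = torch.tensor([1, 3, 5, 7, 9]),
#   torch.bucketize(torch.tensor(3), b)              # -> 1 (first index with boundary >= 3)
#   torch.bucketize(torch.tensor(3), b, right=True)  # -> 2 (first index with boundary >  3)
# and the loop above runs int(log2(5)) = 2 refinement steps after the first probe.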
  5004. @register_decomposition(aten.cauchy)
  5005. @out_wrapper()
  5006. @elementwise_type_promotion_wrapper(
  5007. type_promoting_args=("self",),
  5008. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  5009. )
  5010. def cauchy(self, median=0, sigma=1, generator=None):
  5011. assert generator is None
  5012. torch._check(
  5013. not utils.is_complex_dtype(self.dtype)
  5014. and not utils.is_integer_dtype(self.dtype)
  5015. and not utils.is_boolean_dtype(self.dtype),
  5016. lambda: f"Cauchy distribution is a continuous probability distribution. \
  5017. dtype must be a floating point but you specified {self.dtype}",
  5018. )
  5019. torch._check(
  5020. sigma > 0.0,
  5021. lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
  5022. )
  5023. return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))
@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Exponential distribution is a continuous probability distribution. \
            dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )

    return -1 / rate * torch.log1p(-torch.rand_like(self))
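# Example (illustrative; the inverse-CDF transform above always yields non-negative draws):
#     >>> x = torch.empty(5).exponential_(2.0)   # rate (lambda) = 2
#     >>> bool((x >= 0).all())
#     True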
@register_decomposition(aten.geometric)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def geometric(self, p, generator=None):
    assert generator is None
    # TODO: fix inductor rand_like for integer, bool dtypes
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"geometric not implemented for {self.dtype}",
    )
    torch._check(
        0 < p and p < 1,
        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
    )

    return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1
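# Example (illustrative; the "+ 1" above puts the support on {1, 2, ...}, i.e. the number
# of Bernoulli trials up to and including the first success):
#     >>> x = torch.empty(5).geometric_(0.5)
#     >>> bool((x >= 1).all())
#     True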
@register_decomposition(aten.log_normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def log_normal(self, mean=1, std=2, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"log_normal not implemented for {self.dtype}",
    )
    torch._check(
        0 < std,
        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
    )

    return torch.exp(std * torch.randn_like(self) + mean)
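# Example (illustrative; exponentiating a normal sample makes every draw strictly positive):
#     >>> x = torch.empty(5).log_normal_(mean=1.0, std=2.0)
#     >>> bool((x > 0).all())
#     True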
# TODO: add support for functionalization aten.normal_functional
# NOTE: the device and dtype will be ignored when shape is None
@register_decomposition(aten.normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=(
        "mean",
        "std",
    ),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def normal(
    mean=0,
    std=1,
    size=None,
    *,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    assert layout is None or layout == torch.strided

    if not isinstance(std, TensorLike):
        torch._check(
            std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
        )

    if size is None:
        tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
        torch._check(
            len(tensors) > 0,
            lambda: "normal expects that either mean or std is a tensor, or size is defined",
        )
        torch._check(
            layout is None and pin_memory is None,
            lambda: "Cannot pass layout, or pin_memory without size",
        )

        size = _broadcast_shapes(*(t.shape for t in tensors))
        dtype = tensors[0].dtype
        device = tensors[0].device
    else:
        torch._check(
            not isinstance(mean, TensorLike) and not isinstance(std, TensorLike),
            lambda: "normal expects mean and std to be scalars when size is defined",
        )
        dtype = torch.get_default_dtype() if dtype is None else dtype
        device = torch.device("cpu") if device is None else device

    normal_samples = prims.normal(
        size,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=False,
        generator=generator,
    )
    return std * normal_samples + mean
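# Example (illustrative; size can be given explicitly or inferred by broadcasting a
# tensor-valued mean/std, mirroring eager torch.normal):
#     >>> torch.normal(2.0, 3.0, size=(1, 4)).shape
#     torch.Size([1, 4])
#     >>> torch.normal(torch.zeros(3), torch.ones(3)).shape
#     torch.Size([3])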
@register_decomposition(aten.normal_)
def normal_(self, mean=0, std=1, *, generator=None):
    return normal(mean, std, self.shape, out=self, generator=generator)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "rad2deg is not supported for complex tensors.",
    )
    M_180_PI = 57.295779513082320876798154814105170332405472466564
    return self * M_180_PI
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "deg2rad is not supported for complex tensors.",
    )
    M_PI_180 = 0.017453292519943295769236907684886127134428718885417
    return self * M_PI_180
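# Example (illustrative; the constants above are 180/pi and pi/180 respectively):
#     >>> torch.deg2rad(torch.tensor([180.0]))
#     tensor([3.1416])
#     >>> torch.rad2deg(torch.tensor([math.pi]))   # approximately tensor([180.])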
@register_decomposition(aten.count_nonzero)
@out_wrapper()
def count_nonzero(self, dim: Optional[DimsType] = None):
    return (self != 0).sum(dim)
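# Example (illustrative; dim=None counts over the whole tensor):
#     >>> torch.count_nonzero(torch.tensor([[0, 1], [2, 0]]))
#     tensor(2)
#     >>> torch.count_nonzero(torch.tensor([[0, 1], [2, 0]]), dim=0)
#     tensor([1, 1])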
def _dot_check(self, other):
    torch._check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )

    def numel_error():
        return (
            f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the "
            f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
        )

    torch._check(self.numel() == other.numel(), numel_error)
@register_decomposition(aten.dot)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot(self, other):
    if self.is_complex():
        if self.is_conj():
            if other.is_conj():
                return torch.dot(self.conj(), other.conj()).conj()
            else:
                return torch.vdot(self.conj(), other)
        elif other.is_conj():
            return torch.vdot(other.conj(), self)

    _dot_check(self, other)
    return (self * other).sum()
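# Example (illustrative; plain 1D inner product without conjugation):
#     >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
#     tensor(7)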
@register_decomposition(aten.vdot)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vdot(self, other):
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    _dot_check(self, other)
    # The decomposition fails if you do self.conj()... not sure why
    return (self.conj_physical() * other).sum()
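# Example (illustrative; vdot conjugates its first argument, so for complex inputs it
# matches (self.conj() * other).sum()):
#     >>> a = torch.tensor([1 + 2j, 3 - 1j])
#     >>> b = torch.tensor([2 + 1j, 4 + 0j])
#     >>> torch.vdot(a, b)   # == (a.conj() * b).sum() == 16 + 1j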
@register_decomposition(aten.select_scatter)
@out_wrapper()
def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int):
    dim = utils.canonicalize_dim(x.ndim, dim)
    mask_shape = [1] * x.ndim
    mask_shape[dim] = -1
    if index < 0:
        index = index + x.shape[dim]
    mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index
    src = torch.unsqueeze(src, dim).expand(x.shape)
    return torch.where(mask, src, x)
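# Example (illustrative; writes `src` into the slice selected by (dim, index)):
#     >>> a = torch.zeros(2, 2)
#     >>> b = torch.ones(2)
#     >>> torch.select_scatter(a, b, 0, 0)
#     tensor([[1., 1.],
#             [0., 0.]])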
# inplace
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
cumprod_ = _make_inplace(cumprod)
deg2rad_ = _make_inplace(deg2rad)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
rad2deg_ = _make_inplace(rad2deg)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)
# xref: isStorage in torch/csrc/DynamicTypes.cpp
def _isStorage(obj):
    return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage))
# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp
def _compute_sizes(seq, scalar_type):
    MAX_DIMS = 128
    is_storage = _isStorage(seq)
    sizes = []
    # TODO: this is inaccurate, we actually test PySequence_Check
    while isinstance(seq, (list, tuple)):
        length = len(seq)
        if is_storage:
            length //= scalar_type.itemsize
        sizes.append(length)
        if len(sizes) > MAX_DIMS:
            raise ValueError(f"too many dimensions '{type(seq).__name__}'")
        if length == 0:
            break
        try:
            handle = seq[0]
        except Exception:
            raise ValueError(  # noqa: B904
                f"could not determine the shape of object type '{type(seq).__name__}'"
            )
        seq = handle

    return sizes
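# Example (illustrative; the shape is inferred by walking the first element of each
# nesting level, and scalar_type only matters for storages):
#     >>> _compute_sizes([[1, 2, 3], [4, 5, 6]], torch.int64)
#     [2, 3]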
# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp
def _infer_scalar_type(obj):
    if isinstance(obj, FloatLike):
        return torch.get_default_dtype()
    if isinstance(obj, IntLike) and not isinstance(obj, bool):  # careful!
        return torch.int64
    if isinstance(obj, BoolLike):
        return torch.bool
    if isinstance(obj, complex):
        default_dtype = torch.get_default_dtype()
        if default_dtype is torch.float:
            return torch.cfloat
        elif default_dtype is torch.double:
            return torch.cdouble
        elif default_dtype is torch.half:
            return torch.chalf
        else:
            raise RuntimeError("invalid default scalar type for complex")
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, str):
        raise TypeError(f"new(): invalid data type '{type(obj).__name__}'")
    # TODO: this is inaccurate, we actually test PySequence_Check
    if isinstance(obj, (list, tuple)):
        scalarType = None
        length = len(obj)
        # match NumPy semantics, except use default tensor type instead of
        # double.
        if length == 0:
            return torch.get_default_dtype()
        for i in range(length):
            cur_item = obj[i]
            # TODO: test this
            """
            if cur_item is obj:
                raise TypeError("new(): self-referential lists are incompatible")
            """
            item_scalarType = _infer_scalar_type(cur_item)  # recurse!
            if scalarType is not None:
                scalarType = torch.promote_types(scalarType, item_scalarType)
            else:
                scalarType = item_scalarType

            if scalarType is torch.cdouble:
                # this won't change (unless we hit undefined, but that will
                # fail later)
                return scalarType

        return scalarType
    raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}")
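# Example (illustrative; assumes the default dtype has been left at torch.float32):
#     >>> _infer_scalar_type([1, 2, 3])
#     torch.int64
#     >>> _infer_scalar_type([1, 2.0])   # mixed int/float promotes to the default dtype
#     torch.float32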
# Analogous to recursive_store
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
def _recursive_build(
    scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
):
    if isinstance(obj, Tensor) and obj.numel() == 1:
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
    elif isinstance(obj, Tensor):
        # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode
        # >>> torch.tensor([torch.randn(2)])
        # ValueError: only one element tensors can be converted to Python scalars
        #
        # But it is possible with a NumPy array
        # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
        # torch.Size([1, 2])
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
    elif isinstance(obj, Number):
        return torch.scalar_tensor(obj, dtype=scalarType)

    # seq can be a list of tensors
    seq = obj
    return torch.stack([_recursive_build(scalarType, item) for item in seq])
# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
def _internal_new_from_data(
    options,
    scalar_type,
    device_opt,
    data,
    copy_variables,
    copy_numpy,
    type_inference,
    pin_memory=False,
):
    if isinstance(data, torch.Tensor):
        torch._check(
            not pin_memory, lambda: "Can't pin tensor constructed from a variable"
        )
        var = data
        if copy_variables:
            var = var.detach()
        inferred_scalar_type = var.dtype if type_inference else scalar_type
        device = device_opt if device_opt is not None else var.device
        return var.to(
            device=device,
            dtype=inferred_scalar_type,
            non_blocking=False,
            copy=copy_variables,
        )

    # TODO
    if hasattr(data, "__cuda_array_interface__"):
        return NotImplemented

    # TODO: test for numpy input with PyArray_Check
    device = device_opt if device_opt is not None else options["device"]
    inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type

    # NB: Don't need to avoid tracing, as we aren't going to do any manual
    # pointer filling tricks
    if _isStorage(data):
        return NotImplemented
    else:
        if torch.device(device).type == "meta":
            return NotImplemented

        # In the C implementation, we would directly start poking the memory
        # of a freshly allocated CPU tensor. Here, we're going to do an
        # alternate, heinously slow implementation: turn each individual
        # scalar into a tensor, and then repeatedly cat them together
        tensor = _recursive_build(inferred_scalar_type, data)

        tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)

    # NB: lift_fresh is not needed, because we built the tensor from scalars
    # guaranteeing a fresh tensor in this case
    return tensor
# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
    # TODO (or not): support names kwarg
    if isinstance(data, torch.Tensor):
        warnings.warn(
            "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
            "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)"
        )
    type_inference = dtype is None
    new_tensor = _internal_new_from_data(
        # device="cpu" because that's what you get with torch.tensor(2) no
        # device by default
        {"device": "cpu"},  # TODO: use torch.get_default_tensor_type
        dtype if dtype is not None else torch.get_default_dtype(),
        device,
        data,
        copy_variables=True,
        copy_numpy=True,
        type_inference=type_inference,
        pin_memory=pin_memory,
    )
    new_tensor.detach_()
    if requires_grad:
        new_tensor.requires_grad_(requires_grad)
    return new_tensor
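# Example (illustrative; this reference is intended to mirror the eager torch.tensor
# constructor, so dtype is inferred from the data unless given explicitly):
#     >>> t = tensor([[1, 2], [3, 4]])
#     >>> t.shape, t.dtype
#     (torch.Size([2, 2]), torch.int64)
#     >>> tensor([1.0, 2.0], dtype=torch.float64).dtype
#     torch.float64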
# Views
# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function
# given that it does not reshape the input (it just copies the result into it)
# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)
import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special