# mypy: ignore-errors

from abc import abstractmethod
import tempfile
import unittest
from copy import deepcopy
from functools import reduce, partial
from itertools import product
from operator import mul

import torch
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn

from typing import Dict, Callable, Tuple, List, Sequence, Union, Any

TemporaryFile = tempfile.TemporaryFile
PRECISION = 1e-5

def get_reduction(m):
    result = getattr(m, 'reduction', None)
    if result is None:
        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
    assert result is not None
    return result


def get_weight(m):
    result = getattr(m, 'weight', None)
    if result is not None:
        return result
    return getattr(m, 'weights', None)

# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
#
# The way to check API parity is to add parity tests for the NN module / functional of interest.
# Here are the detailed steps:
#
# For NN module:
# 1. Make sure you already have a test dict with the module configuration you want to test.
# 2. Add a `cpp_constructor_args` entry to the test dict, with its value exactly matching
#    the Python module constructor arguments. For example, if in the test dict we pass
#    `(10, 8)` to the `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
# 3. If in the process of performing the above step you referenced any variables
#    in the `cpp_constructor_args` entry, you must add a `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, if the Python constructor call is
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
#    the corresponding C++ constructor argument is
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
#    and the `cpp_var_map` entry must be
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
#    then you must add a `cpp_options_args` entry to the test dict, with its value exactly matching the Python
#    functional optional arguments. For example, if the test dict's `constructor` entry is
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
#    then the `cpp_options_args` entry should be
#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)".
# 3. Otherwise, if the test dict's `constructor` entry looks like
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
#    then you must add a `cpp_function_call` entry to the test dict, with its value exactly matching the Python
#    functional function call. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
# 4. If in the process of performing the above two steps you referenced any variables
#    in the `cpp_options_args` or `cpp_function_call` entry, you must
#    add a `cpp_var_map` entry to the test dict to make sure that those variables
#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
#    Notice that there are two variables `i` and `t` that need to have their values provided,
#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
#    (Note that for `i`, since we want it to take the Python input value, we pass the '_get_input()' string as value
#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
#
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
#
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
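#
# Putting the functional steps together, a test dict that follows step 2 of the
# "For NN functional" list might look like the sketch below. It is illustrative only
# (kept as a comment, not registered in any test list); the `constructor` /
# `cpp_options_args` pair is the `F.interpolate` example quoted above, while the
# function name and the `fullname`, `input_fn`, and `pickle` entries are placeholder
# values chosen for this example.
#
#   def interpolate_nearest_parity_example_test():
#       return dict(
#           fullname='interpolate_nearest_parity_example',
#           constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
#           cpp_options_args="F::InterpolateFuncOptions()"
#                            ".size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)",
#           input_fn=lambda: torch.randn(1, 2, 3),
#           pickle=False)
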
module_tests = [
    dict(
        module_name='Linear',
        constructor_args=(10, 8),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
        input_size=(4, 10),
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
        with_tf32=True,
        tf32_precision=0.005,
        default_dtype=torch.double,
    ),
    dict(
        module_name='Linear',
        constructor_args=(10, 8, False),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
        with_tf32=True,
        tf32_precision=0.005,
        default_dtype=torch.double,
    ),
    dict(
        module_name='RReLU',
        input_size=(1, 2, 2),
        test_cuda=False,
        default_dtype=torch.double,
    ),
    dict(
        module_name='RReLU',
        constructor_args=(0.1, 0.9),
        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
        input_size=(4, 4, 5),
        desc='with_up_down',
        test_cuda=False,
        default_dtype=torch.double,
    ),
    dict(
        module_name='Flatten',
        input_size=(2, 3, 4, 5),
        reference_fn=lambda i, *_: torch.flatten(i, 1),
        default_dtype=torch.double,
    ),
    # TODO: reference function
    dict(
        module_name='CrossMapLRN2d',
        constructor_args=(5, 5e-3, 1e-3, 2),
        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
        input_size=(2, 3, 6, 6),
        check_gradgrad=False,
        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
        check_batched_grad=False,
        default_dtype=torch.double,
    ),
]

# Generates a rand tensor with non-equal values. This ensures that duplicate
# values won't cause test failures for modules like MaxPooling.
# size should be small, otherwise randperm fails / long overflows.
def _rand_tensor_non_equal(*size):
    total = reduce(mul, size, 1)
    return torch.randperm(total).view(*size).double()


def wrap_functional(fn, **kwargs):
    class FunctionalModule(nn.Module):
        def forward(self, *args):
            return fn(*args, **kwargs)
    return FunctionalModule

def poissonnllloss_no_reduce_test():
    t = torch.randn(10, 10)
    return dict(
        fullname='PoissonNLLLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::poisson_nll_loss('
                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: i.exp() - t.mul(i),
        pickle=False,
        default_dtype=torch.double)


def bceloss_no_reduce_test():
    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
    return dict(
        fullname='BCELoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False,
        precision=7e-4,
        default_dtype=torch.double)


def bceloss_no_reduce_scalar_test():
    t = torch.randn(()).gt(0).to(torch.double)
    return dict(
        fullname='BCELoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False,
        default_dtype=torch.double)


def bceloss_weights_no_reduce_test():
    t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
    weights = torch.rand(10, dtype=torch.double)
    return dict(
        fullname='BCELoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i),
                                             weight=weights.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), '
                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
        precision=3e-4,
        default_dtype=torch.double,
    )


def bceloss_weights_no_reduce_scalar_test():
    t = torch.randn(()).gt(0).to(torch.double)
    weights = torch.rand((), dtype=torch.double)
    return dict(
        fullname='BCELoss_weights_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i),
                                             weight=weights.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy(
            i, t.to(i.options()),
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
        default_dtype=torch.double,
    )

def bce_with_logistic_legacy_enum_test():
    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_legacy_enum',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double,
    )


def bce_with_logistic_no_reduce_test():
    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double,
    )


def bce_with_logistic_no_reduce_scalar_test():
    t = torch.randn(()).gt(0).to(torch.double)
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double,
    )

def kldivloss_with_target_no_reduce_test():
    t = torch.rand(10, 10, dtype=torch.double)
    return dict(
        fullname='KLDivLoss_with_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def kldivloss_no_reduce_test():
    t = torch.rand(10, 10, dtype=torch.double)
    return dict(
        fullname='KLDivLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double,
    )


def kldivloss_no_reduce_scalar_test():
    t = torch.rand((), dtype=torch.double)
    return dict(
        fullname='KLDivLoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(()).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def kldivloss_with_log_target_no_reduce_test():
    t = torch.rand(10, 10, dtype=torch.double).log()
    return dict(
        fullname='KLDivLoss_with_log_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def kldivloss_no_reduce_log_target_test():
    t = torch.rand(10, 10, dtype=torch.double).log()
    return dict(
        fullname='KLDivLoss_no_reduce_log_target',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double,
    )


def kldivloss_no_reduce_scalar_log_target_test():
    t = torch.rand((), dtype=torch.double).log()
    return dict(
        fullname='KLDivLoss_no_reduce_scalar_log_target',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(()).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)

def l1loss_no_reduce_test():
    t = torch.randn(2, 3, 4, dtype=torch.double)
    return dict(
        fullname='L1Loss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def l1loss_no_reduce_complex_test():
    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
    return dict(
        fullname='L1Loss_no_reduce_complex',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False)


def l1loss_no_reduce_scalar_test():
    t = torch.randn((), dtype=torch.double)
    return dict(
        fullname='L1Loss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def mseloss_no_reduce_test():
    input_size = (2, 3, 4, 5)
    target = torch.randn(*input_size, dtype=torch.double)
    return dict(
        fullname='MSELoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        input_size=input_size,
        cpp_var_map={'i': '_get_input()', 'target': target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def mseloss_no_reduce_scalar_test():
    input_size = ()
    target = torch.randn(input_size, dtype=torch.double)
    return dict(
        fullname='MSELoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        input_size=input_size,
        cpp_var_map={'i': '_get_input()', 'target': target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)

def nllloss_no_reduce_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nllloss_no_reduce_ignore_index_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
    return dict(
        fullname='NLLLoss_no_reduce_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
                                 reduction=str(kwargs['reduction']))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nllloss_no_reduce_weights_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none'}

    return dict(
        fullname='NLLLoss_no_reduce_weights',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)


def nllloss_no_reduce_weights_ignore_index_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none',
                'ignore_index': 2}

    return dict(
        fullname='NLLLoss_no_reduce_weights_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)


def nllloss_no_reduce_weights_ignore_index_neg_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none',
                'ignore_index': -1}

    return dict(
        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
        input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)

def nllloss2d_no_reduce_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLoss2d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nllloss2d_no_reduce_ignore_index_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
    return dict(
        fullname='NLLLoss2d_no_reduce_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
                                 reduction=str(kwargs['reduction']))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nllloss2d_no_reduce_weights_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none'}

    return dict(
        fullname='NLLLoss2d_no_reduce_weights',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)


def nlllossNd_no_reduce_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLossNd_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nlllossNd_no_reduce_ignore_index_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
    return dict(
        fullname='NLLLossNd_no_reduce_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
                                 reduction=str(kwargs['reduction']))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False,
        default_dtype=torch.double)


def nlllossNd_no_reduce_weights_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none'}

    return dict(
        fullname='NLLLossNd_no_reduce_weights',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)

def smoothl1loss_no_reduce_test():
    t = torch.randn(2, 3, 4, dtype=torch.double)
    return dict(
        fullname='SmoothL1Loss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def smoothl1loss_no_reduce_scalar_test():
    t = torch.randn((), dtype=torch.double)
    return dict(
        fullname='SmoothL1Loss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def smoothl1loss_beta_test():
    t = torch.randn(2, 3, 4, dtype=torch.double)
    return dict(
        fullname='SmoothL1Loss_beta',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def smoothl1loss_zero_beta_test():
    t = torch.randn(2, 3, 4, dtype=torch.double)
    return dict(
        fullname='SmoothL1Loss_zero_beta',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def huberloss_delta_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='HuberLoss_delta',
        constructor=wrap_functional(
            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
        cpp_function_call='''F::huber_loss(
            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)

def multilabelmarginloss_0d_no_reduce_test():
    t = torch.zeros(()).long()
    return dict(
        fullname='MultiLabelMarginLoss_0d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multilabelmarginloss_1d_no_reduce_test():
    t = Variable(torch.rand(10).mul(10).floor().long())
    return dict(
        fullname='MultiLabelMarginLoss_1d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multilabelmarginloss_index_neg_test():
    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
    return dict(
        fullname='MultiLabelMarginLoss_index_neg',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multilabelmarginloss_no_reduce_test():
    t = Variable(torch.rand(5, 10).mul(10).floor().long())
    return dict(
        fullname='MultiLabelMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)

def hingeembeddingloss_no_reduce_test():
    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
    return dict(
        fullname='HingeEmbeddingLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
        check_sum_reduction=True,
        pickle=False,
        default_dtype=torch.double)


def hingeembeddingloss_margin_no_reduce_test():
    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
    return dict(
        fullname='HingeEmbeddingLoss_margin_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
        check_sum_reduction=True,
        pickle=False,
        default_dtype=torch.double)


def softmarginloss_no_reduce_test():
    t = torch.randn(5, 5, dtype=torch.double)
    return dict(
        fullname='SoftMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::soft_margin_loss(
            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 5),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def multilabelsoftmarginloss_no_reduce_test():
    t = torch.rand(5, 10).mul(2).floor()
    return dict(
        fullname='MultiLabelSoftMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::multilabel_soft_margin_loss(
            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multilabelsoftmarginloss_weights_no_reduce_test():
    t = torch.rand(5, 10).mul(2).floor()
    weights = torch.rand(10)
    return dict(
        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
                                                    weight=weights.type_as(i), reduction='none')),
        cpp_function_call='''F::multilabel_soft_margin_loss(
            i, t.to(i.options()),
            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        reference_fn=lambda i, *_:
            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)

  831. def multimarginloss_no_reduce_test():
  832. t = torch.rand(5).mul(8).floor().long()
  833. return dict(
  834. fullname='MultiMarginLoss_no_reduce',
  835. constructor=wrap_functional(
  836. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
  837. cpp_function_call='''F::multi_margin_loss(
  838. i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
  839. input_fn=lambda: torch.randn(5, 10),
  840. cpp_var_map={'i': '_get_input()', 't': t},
  841. reference_fn=lambda i, *_:
  842. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
  843. check_sum_reduction=True,
  844. check_gradgrad=False,
  845. pickle=False,
  846. default_dtype=torch.double)
  847. def multimarginloss_1d_no_reduce_test():
  848. t = torch.rand(1).mul(8).floor().long()
  849. return dict(
  850. fullname='MultiMarginLoss_1d_no_reduce',
  851. constructor=wrap_functional(
  852. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
  853. cpp_function_call='''F::multi_margin_loss(
  854. i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
  855. input_fn=lambda: torch.randn(10),
  856. cpp_var_map={'i': '_get_input()', 't': t},
  857. reference_fn=lambda i, *_:
  858. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
  859. check_sum_reduction=True,
  860. check_gradgrad=False,
  861. pickle=False,
  862. default_dtype=torch.double)
  863. def multimarginloss_1d_input_0d_target_no_reduce_test():
  864. t = torch.rand(()).mul(8).floor().long()
  865. return dict(
866. fullname='MultiMarginLoss_1d_input_0d_target_no_reduce',
  867. constructor=wrap_functional(
  868. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
  869. cpp_function_call='''F::multi_margin_loss(
  870. i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
  871. input_fn=lambda: torch.randn(10),
  872. cpp_var_map={'i': '_get_input()', 't': t},
  873. reference_fn=lambda i, *_:
  874. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
  875. check_sum_reduction=True,
  876. check_gradgrad=False,
  877. pickle=False,
  878. default_dtype=torch.double)
  879. def multimarginloss_p_no_reduce_test():
  880. t = torch.rand(5).mul(8).floor().long()
  881. return dict(
  882. fullname='MultiMarginLoss_p_no_reduce',
  883. constructor=wrap_functional(
  884. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
  885. cpp_function_call='''F::multi_margin_loss(
  886. i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
  887. input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
  888. cpp_var_map={'i': '_get_input()', 't': t},
  889. reference_fn=lambda i, *_:
  890. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
  891. check_sum_reduction=True,
  892. check_gradgrad=False,
  893. pickle=False,
  894. default_dtype=torch.double)
  895. def multimarginloss_margin_no_reduce_test():
  896. t = torch.rand(5).mul(8).floor().long()
  897. return dict(
  898. fullname='MultiMarginLoss_margin_no_reduce',
  899. constructor=wrap_functional(
  900. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
  901. cpp_function_call='''F::multi_margin_loss(
  902. i, t.to(i.options()).to(torch::kLong),
  903. F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
  904. input_fn=lambda: torch.randn(5, 10),
  905. cpp_var_map={'i': '_get_input()', 't': t},
  906. reference_fn=lambda i, *_:
  907. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
  908. margin=0.5, reduction='none'),
  909. check_sum_reduction=True,
  910. check_gradgrad=False,
  911. pickle=False,
  912. default_dtype=torch.double)
  913. def multimarginloss_weights_no_reduce_test():
  914. t = torch.rand(5).mul(8).floor().long()
  915. weights = torch.rand(10, dtype=torch.double)
  916. return dict(
  917. fullname='MultiMarginLoss_weights_no_reduce',
  918. constructor=wrap_functional(
  919. lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
  920. reduction='none')),
  921. cpp_function_call='''F::multi_margin_loss(
  922. i, t.to(i.options()).to(torch::kLong),
  923. F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
  924. input_fn=lambda: torch.randn(5, 10),
  925. cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
  926. reference_fn=lambda i, *_:
  927. loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
  928. weight=weights, reduction='none'),
  929. check_sum_reduction=True,
  930. check_gradgrad=False,
  931. pickle=False,
  932. default_dtype=torch.double)
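# The `*_no_reduce_test` helpers above all follow the same pattern: wrap a functional
# loss with `reduction='none'` via `wrap_functional`, mirror the call in
# `cpp_function_call` for the C++ API parity test, and pin the expected per-element
# values with a `reference_fn` (either a closed-form expression or an entry from
# `loss_reference_fns`). `pickle=False` is set presumably because the wrapping lambdas
# capture local tensors and cannot be pickled.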
  933. def single_batch_reference_fn(input, parameters, module):
  934. """Reference function for modules supporting no batch dimensions.
  935. The module is passed the input and target in batched form with a single item.
  936. The output is squeezed to compare with the no-batch input.
  937. """
  938. def unsqueeze_inp(inp):
  939. if isinstance(inp, (list, tuple)):
  940. return [t.unsqueeze(0) for t in inp]
  941. return inp.unsqueeze(0)
  942. single_batch_input = unsqueeze_inp(input)
  943. single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
  944. with freeze_rng_state():
  945. return module(*single_batch_input).squeeze(0)
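# A minimal, illustrative sketch (not used by the test harness) of how an entry from
# `new_module_tests` below can be exercised by hand: build the module from
# `constructor_args`, run it on a tensor of `input_size`, and inspect the output.
# It assumes the usual `torch` / `torch.nn as nn` imports at the top of this file;
# `_example_run_module_test_entry` is a hypothetical helper, nothing in this file calls it.
def _example_run_module_test_entry():
    entry = dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 3),
        input_size=(2, 4, 10),
        default_dtype=torch.double,
    )
    # Instantiate the module named by the entry and cast it to the entry's dtype.
    module = getattr(nn, entry['module_name'])(*entry['constructor_args']).to(entry['default_dtype'])
    inp = torch.randn(entry['input_size'], dtype=entry['default_dtype'])
    out = module(inp)
    # Conv1d(4, 5, 3), stride 1, no padding: L_out = 10 - 3 + 1 = 8, so shape (2, 5, 8).
    return out.shape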
  946. new_module_tests = [
  947. poissonnllloss_no_reduce_test(),
  948. bceloss_no_reduce_test(),
  949. bceloss_weights_no_reduce_test(),
  950. bce_with_logistic_legacy_enum_test(),
  951. bce_with_logistic_no_reduce_test(),
  952. bceloss_no_reduce_scalar_test(),
  953. bceloss_weights_no_reduce_scalar_test(),
  954. bce_with_logistic_no_reduce_scalar_test(),
  955. kldivloss_with_target_no_reduce_test(),
  956. kldivloss_no_reduce_test(),
  957. kldivloss_no_reduce_scalar_test(),
  958. kldivloss_with_log_target_no_reduce_test(),
  959. kldivloss_no_reduce_log_target_test(),
  960. kldivloss_no_reduce_scalar_log_target_test(),
  961. l1loss_no_reduce_test(),
  962. l1loss_no_reduce_complex_test(),
  963. l1loss_no_reduce_scalar_test(),
  964. mseloss_no_reduce_test(),
  965. mseloss_no_reduce_scalar_test(),
  966. nllloss_no_reduce_test(),
  967. nllloss_no_reduce_ignore_index_test(),
  968. nllloss_no_reduce_weights_test(),
  969. nllloss_no_reduce_weights_ignore_index_test(),
  970. nllloss_no_reduce_weights_ignore_index_neg_test(),
  971. nllloss2d_no_reduce_test(),
  972. nllloss2d_no_reduce_weights_test(),
  973. nllloss2d_no_reduce_ignore_index_test(),
  974. nlllossNd_no_reduce_test(),
  975. nlllossNd_no_reduce_weights_test(),
  976. nlllossNd_no_reduce_ignore_index_test(),
  977. smoothl1loss_no_reduce_test(),
  978. smoothl1loss_no_reduce_scalar_test(),
  979. smoothl1loss_beta_test(),
  980. smoothl1loss_zero_beta_test(),
  981. huberloss_delta_test(),
  982. multilabelmarginloss_0d_no_reduce_test(),
  983. multilabelmarginloss_1d_no_reduce_test(),
  984. multilabelmarginloss_index_neg_test(),
  985. multilabelmarginloss_no_reduce_test(),
  986. hingeembeddingloss_no_reduce_test(),
  987. hingeembeddingloss_margin_no_reduce_test(),
  988. softmarginloss_no_reduce_test(),
  989. multilabelsoftmarginloss_no_reduce_test(),
  990. multilabelsoftmarginloss_weights_no_reduce_test(),
  991. multimarginloss_no_reduce_test(),
  992. multimarginloss_1d_no_reduce_test(),
  993. multimarginloss_1d_input_0d_target_no_reduce_test(),
  994. multimarginloss_p_no_reduce_test(),
  995. multimarginloss_margin_no_reduce_test(),
  996. multimarginloss_weights_no_reduce_test(),
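# From here on the entries are written as inline dicts rather than helper functions.
# Roughly: `module_name` + `constructor_args` (or `fullname` + `constructor`) build the
# module, `cpp_constructor_args` is its C++ counterpart for the API parity test,
# `input_size` / `input_fn` produce the forward input, `desc` is appended to the test
# name, and `with_tf32` / `tf32_precision` appear to relax the comparison tolerance
# when TF32 math is enabled on supported GPUs.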
  997. dict(
  998. module_name='Conv1d',
  999. constructor_args=(4, 5, 3),
  1000. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
  1001. input_size=(2, 4, 10),
  1002. cudnn=True,
  1003. with_tf32=True,
  1004. tf32_precision=0.005,
  1005. default_dtype=torch.double,
  1006. ),
  1007. dict(
  1008. module_name='Conv1d',
  1009. constructor_args=(4, 5, 3, 2),
  1010. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
  1011. input_size=(2, 4, 10),
  1012. cudnn=True,
  1013. desc='stride',
  1014. with_tf32=True,
  1015. tf32_precision=0.005,
  1016. default_dtype=torch.double,
  1017. ),
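# For the 'stride' case above: Conv1d(4, 5, 3, stride=2) on (2, 4, 10) gives
# L_out = floor((10 - 3) / 2) + 1 = 4, i.e. an output of shape (2, 5, 4).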
  1018. dict(
  1019. module_name='Conv1d',
  1020. constructor_args=(4, 5, 3, 1, 1),
  1021. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
  1022. input_size=(2, 4, 10),
  1023. cudnn=True,
  1024. desc='pad1',
  1025. with_tf32=True,
  1026. tf32_precision=0.01,
  1027. default_dtype=torch.double,
  1028. ),
  1029. dict(
  1030. module_name='Conv1d',
  1031. constructor_args=(4, 5, 5, 1, 2),
  1032. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
  1033. input_size=(2, 4, 10),
  1034. cudnn=True,
  1035. desc='pad2',
  1036. with_tf32=True,
  1037. tf32_precision=0.005,
  1038. default_dtype=torch.double,
  1039. ),
  1040. dict(
  1041. module_name='Conv1d',
  1042. constructor_args=(4, 4, 3, 1, 1),
  1043. cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
  1044. input_size=(1, 4, 1),
  1045. cudnn=True,
  1046. desc='pad1size1',
  1047. with_tf32=True,
  1048. tf32_precision=0.005,
  1049. default_dtype=torch.double,
  1050. ),
  1051. dict(
  1052. module_name='Conv1d',
  1053. constructor_args=(4, 4, 5, 1, 2),
  1054. cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
  1055. input_size=(1, 4, 1),
  1056. cudnn=True,
  1057. desc='pad2size1',
  1058. with_tf32=True,
  1059. tf32_precision=0.005,
  1060. default_dtype=torch.double,
  1061. ),
  1062. dict(
  1063. module_name='Conv1d',
  1064. constructor_args=(4, 5, 3),
  1065. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
  1066. input_size=(0, 4, 10),
  1067. cudnn=True,
  1068. desc='zero_batch',
  1069. with_tf32=True,
  1070. tf32_precision=0.005,
  1071. ),
  1072. dict(
  1073. fullname='Conv1d_dilated',
  1074. constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
  1075. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
  1076. input_size=(2, 4, 10),
  1077. with_tf32=True,
  1078. tf32_precision=0.005,
  1079. default_dtype=torch.double,
  1080. ),
  1081. dict(
  1082. fullname='Conv1d_groups',
  1083. constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
  1084. cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
  1085. input_size=(2, 4, 6),
  1086. cudnn=True,
  1087. with_tf32=True,
  1088. tf32_precision=0.005,
  1089. default_dtype=torch.double,
  1090. ),
  1091. dict(
  1092. fullname='Conv1d_pad_valid',
  1093. constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
  1094. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
  1095. input_size=(2, 4, 10),
  1096. cudnn=True,
  1097. with_tf32=True,
  1098. tf32_precision=0.005,
  1099. default_dtype=torch.double,
  1100. ),
  1101. dict(
  1102. fullname='Conv1d_pad_same',
  1103. constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
  1104. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
  1105. input_size=(2, 4, 10),
  1106. cudnn=True,
  1107. with_tf32=True,
  1108. tf32_precision=0.005,
  1109. default_dtype=torch.double,
  1110. ),
  1111. dict(
  1112. fullname='Conv1d_pad_same2',
  1113. constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
  1114. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
  1115. input_size=(2, 4, 10),
  1116. cudnn=True,
  1117. with_tf32=True,
  1118. tf32_precision=0.005,
  1119. default_dtype=torch.double,
  1120. ),
  1121. dict(
  1122. fullname='Conv1d_pad_same_dilated',
  1123. constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
1124. cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame).dilation(2)',
  1125. input_size=(2, 4, 10),
  1126. cudnn=True,
  1127. with_tf32=True,
  1128. tf32_precision=0.005,
  1129. default_dtype=torch.double,
  1130. ),
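# padding="same" keeps the output length equal to the input length (10 here) for the
# stride-1 convolutions above; with kernel_size=4 and dilation=2 the effective receptive
# field per output element is 1 + (4 - 1) * 2 = 7.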
  1131. dict(
  1132. fullname='ConvTranspose1d',
  1133. constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
  1134. cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
  1135. cudnn=True,
  1136. input_size=(1, 3, 7),
  1137. with_tf32=True,
  1138. tf32_precision=0.005,
  1139. default_dtype=torch.double,
  1140. ),
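# For the entry above: the ConvTranspose1d output length is
# (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
# = (7 - 1) * 3 - 2 + 2 + 1 + 1 = 20, so the input (1, 3, 7) maps to (1, 4, 20).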
  1141. dict(
  1142. module_name='ConvTranspose1d',
  1143. constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
  1144. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
  1145. .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
  1146. input_size=(1, 3, 6),
  1147. cudnn=True,
  1148. desc='no_bias',
  1149. with_tf32=True,
  1150. tf32_precision=0.005,
  1151. default_dtype=torch.double,
  1152. ),
  1153. dict(
  1154. module_name='ConvTranspose1d',
  1155. constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
  1156. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
  1157. .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
  1158. input_size=(1, 3, 6),
  1159. cudnn=True,
  1160. desc='dilated',
  1161. with_tf32=True,
  1162. tf32_precision=0.005,
  1163. default_dtype=torch.double,
  1164. ),
  1165. dict(
  1166. fullname='ConvTranspose1d_groups',
  1167. constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
  1168. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
  1169. .stride(3).padding(1).output_padding(1).groups(2)''',
  1170. cudnn=True,
  1171. input_size=(2, 4, 7),
  1172. with_tf32=True,
  1173. tf32_precision=0.005,
  1174. default_dtype=torch.double,
  1175. ),
  1176. dict(
  1177. module_name='Conv2d',
  1178. constructor_args=(3, 4, (3, 2)),
  1179. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
  1180. input_size=(2, 3, 7, 5),
  1181. cudnn=True,
  1182. check_with_long_tensor=True,
  1183. with_tf32=True,
  1184. tf32_precision=0.005,
  1185. default_dtype=torch.double,
  1186. ),
  1187. dict(
  1188. module_name='Conv2d',
  1189. constructor_args=(3, 4, (3, 3), (2, 2)),
  1190. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
  1191. input_size=(2, 3, 6, 6),
  1192. cudnn=True,
  1193. desc='strided',
  1194. check_with_long_tensor=True,
  1195. with_tf32=True,
  1196. tf32_precision=0.005,
  1197. default_dtype=torch.double,
  1198. ),
  1199. dict(
  1200. module_name='Conv2d',
  1201. constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
  1202. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
  1203. input_size=(2, 3, 6, 6),
  1204. cudnn=True,
  1205. desc='padding',
  1206. check_with_long_tensor=True,
  1207. with_tf32=True,
  1208. tf32_precision=0.005,
  1209. default_dtype=torch.double,
  1210. ),
  1211. dict(
  1212. module_name='Conv2d',
  1213. constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
  1214. cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
  1215. input_size=(2, 3, 8, 8),
  1216. cudnn=True,
  1217. desc='dilated',
  1218. check_with_long_tensor=True,
  1219. with_tf32=True,
  1220. tf32_precision=0.005,
  1221. default_dtype=torch.double,
  1222. ),
  1223. dict(
  1224. module_name='Conv2d',
  1225. constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
  1226. cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
  1227. .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
  1228. input_size=(2, 3, 6, 5),
  1229. cudnn=True,
  1230. desc='no_bias',
  1231. check_with_long_tensor=True,
  1232. with_tf32=True,
  1233. tf32_precision=0.015,
  1234. default_dtype=torch.double,
  1235. ),
  1236. dict(
  1237. module_name='Conv2d',
  1238. constructor_args=(3, 4, (3, 2)),
  1239. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
  1240. input_size=(0, 3, 7, 5),
  1241. cudnn=True,
  1242. desc='zero_batch',
  1243. check_with_long_tensor=True,
  1244. with_tf32=True,
  1245. ),
  1246. dict(
  1247. fullname='Conv2d_groups',
  1248. constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
  1249. cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
  1250. input_size=(2, 4, 6, 5),
  1251. cudnn=True,
  1252. check_with_long_tensor=True,
  1253. with_tf32=True,
  1254. tf32_precision=0.015,
  1255. default_dtype=torch.double,
  1256. ),
  1257. dict(
  1258. fullname='Conv2d_groups_thnn',
  1259. constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
  1260. cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
  1261. input_size=(2, 4, 6, 5),
  1262. check_with_long_tensor=True,
  1263. with_tf32=True,
  1264. tf32_precision=0.015,
  1265. default_dtype=torch.double,
  1266. ),
  1267. dict(
  1268. fullname='Conv2d_pad_valid',
  1269. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
  1270. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
  1271. input_size=(2, 2, 6, 5),
  1272. cudnn=True,
  1273. with_tf32=True,
  1274. tf32_precision=0.005,
  1275. default_dtype=torch.double,
  1276. ),
  1277. dict(
  1278. fullname='Conv2d_pad_same',
  1279. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
  1280. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
  1281. input_size=(2, 2, 6, 5),
  1282. cudnn=True,
  1283. with_tf32=True,
  1284. tf32_precision=0.01,
  1285. default_dtype=torch.double,
  1286. ),
  1287. dict(
  1288. fullname='Conv2d_pad_same_dilated',
  1289. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
  1290. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
  1291. input_size=(2, 2, 6, 5),
  1292. cudnn=True,
  1293. with_tf32=True,
  1294. tf32_precision=0.01,
  1295. default_dtype=torch.double,
  1296. ),
  1297. dict(
  1298. module_name='ConvTranspose2d',
  1299. constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
  1300. cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
  1301. .stride({3, 2}).padding(1).output_padding({1, 1})''',
  1302. cudnn=True,
  1303. input_size=(1, 3, 7, 6),
  1304. check_with_long_tensor=True,
  1305. with_tf32=True,
  1306. tf32_precision=0.01,
  1307. default_dtype=torch.double,
  1308. ),
  1309. dict(
  1310. module_name='ConvTranspose2d',
  1311. constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
  1312. cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
  1313. .stride({2, 3})
  1314. .padding(1)
  1315. .output_padding({1, 1})
  1316. .groups(1)
  1317. .bias(false)
  1318. .dilation({2, 2})''',
  1319. input_size=(1, 3, 6, 7),
  1320. cudnn=True,
  1321. desc='dilated',
  1322. check_with_long_tensor=True,
  1323. with_tf32=True,
  1324. tf32_precision=0.01,
  1325. default_dtype=torch.double,
  1326. ),
  1327. dict(
  1328. module_name='ConvTranspose2d',
  1329. constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
  1330. cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
  1331. .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
  1332. input_size=(1, 3, 6, 7),
  1333. cudnn=True,
  1334. desc='no_bias',
  1335. check_with_long_tensor=True,
  1336. with_tf32=True,
  1337. tf32_precision=0.01,
  1338. default_dtype=torch.double,
  1339. ),
  1340. dict(
  1341. fullname='ConvTranspose2d_groups',
  1342. constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
  1343. cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
  1344. input_size=(1, 2, 4, 5),
  1345. cudnn=True,
  1346. check_with_long_tensor=True,
  1347. with_tf32=True,
  1348. tf32_precision=0.01,
  1349. default_dtype=torch.double,
  1350. ),
  1351. dict(
  1352. fullname='Conv2d_depthwise',
  1353. constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
  1354. cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
  1355. input_size=(2, 4, 6, 6),
  1356. with_tf32=True,
  1357. tf32_precision=0.005,
  1358. default_dtype=torch.double,
  1359. ),
  1360. dict(
  1361. fullname='Conv2d_depthwise_with_multiplier',
  1362. constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
  1363. cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
  1364. input_size=(2, 4, 6, 6),
  1365. with_tf32=True,
  1366. tf32_precision=0.005,
  1367. default_dtype=torch.double,
  1368. ),
  1369. dict(
  1370. fullname='Conv2d_depthwise_strided',
  1371. constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
  1372. cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
  1373. input_size=(2, 4, 6, 6),
  1374. with_tf32=True,
  1375. tf32_precision=0.005,
  1376. default_dtype=torch.double,
  1377. ),
  1378. dict(
  1379. fullname='Conv2d_depthwise_padded',
  1380. constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
  1381. cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
  1382. input_size=(2, 4, 6, 6),
  1383. with_tf32=True,
  1384. tf32_precision=0.005,
  1385. default_dtype=torch.double,
  1386. ),
  1387. dict(
  1388. fullname='Conv2d_depthwise_dilated',
  1389. constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
  1390. cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
  1391. input_size=(2, 4, 5, 5),
  1392. with_tf32=True,
  1393. tf32_precision=0.005,
  1394. default_dtype=torch.double,
  1395. ),
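# The 'Conv2d_depthwise*' entries above set groups equal to the number of input channels
# (4), so each input channel is convolved with its own filter(s); the '_with_multiplier'
# variant has out_channels=8, i.e. two filters per input channel.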
  1396. dict(
  1397. module_name='Conv3d',
  1398. constructor_args=(2, 3, (2, 3, 2)),
  1399. cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
  1400. input_size=(1, 2, 4, 5, 4),
  1401. cudnn=True,
  1402. check_with_long_tensor=True,
  1403. with_tf32=True,
  1404. tf32_precision=0.05,
  1405. default_dtype=torch.double,
  1406. ),
  1407. dict(
  1408. module_name='Conv3d',
  1409. constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
  1410. cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
  1411. .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
  1412. input_size=(1, 2, 3, 4, 5),
  1413. cudnn=True,
  1414. desc='no_bias',
  1415. check_with_long_tensor=True,
  1416. with_tf32=True,
  1417. tf32_precision=0.05,
  1418. default_dtype=torch.double,
  1419. ),
  1420. dict(
  1421. module_name='Conv3d',
  1422. constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
1423. cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {1, 1, 1})
  1424. .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
  1425. input_size=(1, 2, 3, 4, 5),
  1426. cudnn=True,
  1427. desc='1x1x1_no_bias',
  1428. check_with_long_tensor=False,
  1429. with_tf32=True,
  1430. tf32_precision=0.05,
  1431. default_dtype=torch.double,
  1432. ),
  1433. dict(
  1434. module_name='Conv3d',
  1435. constructor_args=(3, 4, 2, 2),
  1436. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
  1437. input_size=(2, 3, 5, 5, 5),
  1438. cudnn=True,
  1439. desc='stride',
  1440. check_with_long_tensor=True,
  1441. with_tf32=True,
  1442. tf32_precision=0.05,
  1443. default_dtype=torch.double,
  1444. ),
  1445. dict(
  1446. module_name='Conv3d',
  1447. constructor_args=(3, 4, 2, 2, 1),
  1448. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
  1449. input_size=(2, 3, 5, 5, 5),
  1450. cudnn=True,
  1451. desc='stride_padding',
  1452. check_with_long_tensor=True,
  1453. with_tf32=True,
  1454. tf32_precision=0.05,
  1455. default_dtype=torch.double,
  1456. ),
  1457. dict(
  1458. module_name='Conv3d',
  1459. constructor_args=(3, 4, (2, 3, 4)),
  1460. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
  1461. input_size=(0, 3, 3, 4, 5),
  1462. cudnn=True,
  1463. check_with_long_tensor=True,
  1464. desc='zero_batch',
  1465. with_tf32=True,
  1466. ),
  1467. dict(
  1468. fullname='Conv3d_groups',
  1469. constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
  1470. cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
  1471. input_size=(1, 2, 4, 5, 4),
  1472. cudnn=True,
  1473. check_with_long_tensor=True,
  1474. with_tf32=True,
  1475. tf32_precision=0.005,
  1476. default_dtype=torch.double,
  1477. ),
  1478. dict(
  1479. fullname='Conv3d_dilated',
  1480. constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
  1481. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
  1482. input_size=(2, 3, 5, 5, 5),
  1483. with_tf32=True,
  1484. tf32_precision=0.05,
  1485. default_dtype=torch.double,
  1486. ),
  1487. dict(
  1488. fullname='Conv3d_dilated_strided',
  1489. constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
  1490. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
  1491. input_size=(2, 3, 5, 5, 5),
  1492. with_tf32=True,
  1493. tf32_precision=0.05,
  1494. default_dtype=torch.double,
  1495. ),
  1496. dict(
  1497. fullname='Conv3d_pad_valid',
  1498. constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
  1499. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
  1500. input_size=(2, 3, 6, 5, 4),
  1501. cudnn=True,
  1502. with_tf32=True,
  1503. tf32_precision=0.05,
  1504. default_dtype=torch.double,
  1505. ),
  1506. dict(
  1507. fullname='Conv3d_pad_same',
  1508. constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
  1509. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
  1510. input_size=(2, 3, 6, 5, 4),
  1511. cudnn=True,
  1512. with_tf32=True,
  1513. tf32_precision=0.05,
  1514. default_dtype=torch.double,
  1515. ),
  1516. dict(
  1517. fullname='Conv3d_pad_same_dilated',
  1518. constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
  1519. cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
  1520. input_size=(2, 3, 6, 5, 4),
  1521. cudnn=True,
  1522. with_tf32=True,
  1523. tf32_precision=0.05,
  1524. default_dtype=torch.double,
  1525. ),
  1526. dict(
  1527. module_name='ConvTranspose3d',
  1528. constructor_args=(2, 3, (2, 3, 2)),
  1529. cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
  1530. cudnn=True,
  1531. input_size=(1, 2, 4, 5, 4),
  1532. with_tf32=True,
  1533. tf32_precision=0.05,
  1534. default_dtype=torch.double,
  1535. ),
  1536. dict(
  1537. module_name='ConvTranspose3d',
  1538. constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
  1539. cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
  1540. .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
  1541. cudnn=True,
  1542. input_size=(1, 2, 4, 5, 4),
  1543. desc='dilated',
  1544. with_tf32=True,
  1545. tf32_precision=0.05,
  1546. default_dtype=torch.double,
  1547. ),
  1548. dict(
  1549. module_name='ReplicationPad3d',
  1550. constructor_args=((1, 2, 3, 3, 2, 1),),
  1551. cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
  1552. input_size=(2, 3, 2, 2, 2),
  1553. default_dtype=torch.double,
  1554. ),
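# ReplicationPad3d takes the padding as (left, right, top, bottom, front, back), so
# (1, 2, 3, 3, 2, 1) maps the (2, 3, 2, 2, 2) input above to
# (2, 3, 2+2+1, 2+3+3, 2+1+2) = (2, 3, 5, 8, 5).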
  1555. dict(
  1556. module_name='ReplicationPad3d',
  1557. constructor_args=((1, 2, 3, 3, 2, 1),),
  1558. cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
  1559. input_size=(3, 2, 2, 2),
  1560. reference_fn=single_batch_reference_fn,
  1561. desc='no_batch_dim',
  1562. default_dtype=torch.double,
  1563. ),
  1564. dict(
  1565. module_name='ReplicationPad3d',
  1566. constructor_args=((1, 2, 3, 3, 2, 1),),
  1567. cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
  1568. input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
  1569. skip_half=True,
  1570. desc='complex'
  1571. ),
  1572. dict(
  1573. module_name='Embedding',
  1574. constructor_args=(4, 3),
  1575. cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
  1576. input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
  1577. check_gradgrad=False,
  1578. default_dtype=torch.double,
  1579. decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
  1580. ),
  1581. dict(
  1582. module_name='Embedding',
  1583. constructor_args=(4, 3),
  1584. cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
  1585. input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
  1586. check_gradgrad=False,
  1587. desc='discontiguous',
  1588. default_dtype=torch.double,
  1589. decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
  1590. ),
  1591. dict(
  1592. module_name='EmbeddingBag',
  1593. constructor_args=(4, 3),
  1594. cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
  1595. input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
  1596. check_gradgrad=False,
  1597. desc='mean',
  1598. default_dtype=torch.double,
  1599. ),
  1600. dict(
  1601. module_name='EmbeddingBag',
  1602. constructor_args=(4, 3),
  1603. cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
  1604. input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
  1605. check_gradgrad=False,
  1606. desc='discontiguous',
  1607. default_dtype=torch.double,
  1608. ),
  1609. dict(
  1610. module_name='EmbeddingBag',
  1611. constructor_args=(4, 3, None, 2., False, 'sum'),
  1612. cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
  1613. .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
  1614. input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
  1615. check_gradgrad=False,
  1616. desc='sum',
  1617. default_dtype=torch.double,
  1618. ),
  1619. dict(
  1620. module_name='EmbeddingBag',
  1621. constructor_args=(4, 3, None, 2., False, 'max'),
  1622. cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
  1623. .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
  1624. input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
  1625. check_gradgrad=False,
  1626. desc='max',
  1627. default_dtype=torch.double,
  1628. ),
  1629. dict(
  1630. fullname='EmbeddingBag_mean_padding_idx',
  1631. constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
  1632. cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
  1633. input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
  1634. check_gradgrad=False,
  1635. default_dtype=torch.double,
  1636. ),
  1637. dict(
  1638. fullname='EmbeddingBag_sum_padding_idx',
  1639. constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
  1640. cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
  1641. .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
  1642. input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
  1643. check_gradgrad=False,
  1644. default_dtype=torch.double,
  1645. ),
  1646. dict(
  1647. fullname='EmbeddingBag_max_padding_idx',
  1648. constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
  1649. cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
  1650. .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
  1651. input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
  1652. check_gradgrad=False,
  1653. default_dtype=torch.double,
  1654. ),
  1655. dict(
  1656. fullname='EmbeddingBag_sparse',
  1657. constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
  1658. cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
  1659. input_fn=lambda: torch.randperm(2).repeat(1, 2),
  1660. check_gradgrad=False,
  1661. has_sparse_gradients=True,
  1662. ),
  1663. dict(
  1664. constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
  1665. cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
  1666. input_fn=lambda: torch.randperm(2).repeat(1, 2),
  1667. fullname='Embedding_sparse',
  1668. check_gradgrad=False,
  1669. has_sparse_gradients=True,
  1670. ),
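# The '_sparse' Embedding/EmbeddingBag entries construct the layers with sparse=True, so
# their weight gradients come back as sparse tensors; `has_sparse_gradients=True` appears
# to tell the harness to expect that, and the gradgrad checks are skipped for them.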
  1671. dict(
  1672. module_name='PixelShuffle',
  1673. constructor_args=(3,),
  1674. cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
  1675. input_size=(1, 9, 4, 4),
  1676. default_dtype=torch.double,
  1677. ),
  1678. dict(
  1679. module_name='PixelUnshuffle',
  1680. constructor_args=(3,),
  1681. cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
  1682. input_size=(1, 1, 12, 12),
  1683. default_dtype=torch.double,
  1684. ),
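# PixelShuffle(3) rearranges (1, 9, 4, 4) into (1, 9 / 3**2, 4 * 3, 4 * 3) = (1, 1, 12, 12);
# PixelUnshuffle(3) is the inverse, taking (1, 1, 12, 12) back to (1, 9, 4, 4).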
  1685. dict(
  1686. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  1687. cpp_options_args='''F::InterpolateFuncOptions()
  1688. .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
  1689. input_size=(1, 2, 4),
  1690. fullname='interpolate_nearest_1d',
  1691. pickle=False,
  1692. default_dtype=torch.double,
  1693. ),
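# The interpolate entries bind F.interpolate with fixed keyword arguments via
# wrap_functional, mirrored by `cpp_options_args` for the C++ parity test. Exactly one of
# `size` and `scale_factor` is given: size=12 on (1, 2, 4) yields (1, 2, 12), while
# scale_factor=4. on the same input would yield (1, 2, 16).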
  1694. dict(
  1695. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  1696. cpp_options_args='''F::InterpolateFuncOptions()
  1697. .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
  1698. input_size=(0, 2, 4),
  1699. fullname='interpolate_nearest_1d_zero_dim',
  1700. pickle=False,
  1701. ),
  1702. dict(
  1703. constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
  1704. cpp_options_args='''F::InterpolateFuncOptions()
  1705. .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
  1706. input_size=(1, 2, 3),
  1707. fullname='interpolate_nearest_tuple_1d',
  1708. pickle=False,
  1709. default_dtype=torch.double,
  1710. ),
  1711. dict(
  1712. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
  1713. cpp_options_args='''F::InterpolateFuncOptions()
  1714. .size(c10::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''',
  1715. input_size=(1, 2, 4),
  1716. fullname='interpolate_nearest_scale_1d',
  1717. pickle=False,
  1718. default_dtype=torch.double,
  1719. ),
  1720. dict(
  1721. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
  1722. cpp_options_args='''F::InterpolateFuncOptions()
  1723. .size(std::vector<int64_t>({12}))
  1724. .scale_factor(c10::nullopt)
  1725. .mode(torch::kLinear)
  1726. .align_corners(false)''',
  1727. input_size=(1, 2, 4),
  1728. fullname='interpolate_linear_1d',
  1729. pickle=False,
  1730. default_dtype=torch.double,
  1731. ),
  1732. dict(
  1733. constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
  1734. cpp_options_args='''F::InterpolateFuncOptions()
  1735. .size(std::vector<int64_t>({4}))
  1736. .scale_factor(c10::nullopt)
  1737. .mode(torch::kLinear)
  1738. .align_corners(false)''',
  1739. input_size=(1, 2, 3),
  1740. fullname='interpolate_linear_tuple_1d',
  1741. pickle=False,
  1742. default_dtype=torch.double,
  1743. ),
  1744. dict(
  1745. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
  1746. cpp_options_args='''F::InterpolateFuncOptions()
  1747. .size(c10::nullopt)
  1748. .scale_factor(std::vector<double>({4.}))
  1749. .mode(torch::kLinear)
  1750. .align_corners(false)''',
  1751. input_size=(1, 2, 4),
  1752. fullname='interpolate_linear_scale_1d',
  1753. pickle=False,
  1754. default_dtype=torch.double,
  1755. ),
  1756. dict(
  1757. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
  1758. cpp_options_args='''F::InterpolateFuncOptions()
  1759. .size(std::vector<int64_t>({12}))
  1760. .scale_factor(c10::nullopt)
  1761. .mode(torch::kLinear)
  1762. .align_corners(false)''',
  1763. input_size=(0, 2, 4),
  1764. fullname='interpolate_linear_1d_zero_dim',
  1765. pickle=False,
  1766. ),
  1767. dict(
  1768. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
  1769. cpp_options_args='''F::InterpolateFuncOptions()
  1770. .size(std::vector<int64_t>({12}))
  1771. .scale_factor(c10::nullopt)
  1772. .mode(torch::kLinear)
  1773. .align_corners(true)''',
  1774. input_size=(1, 2, 4),
  1775. fullname='interpolate_linear_1d_align_corners',
  1776. pickle=False,
  1777. default_dtype=torch.double,
  1778. ),
  1779. dict(
  1780. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
  1781. cpp_options_args='''F::InterpolateFuncOptions()
  1782. .size(c10::nullopt)
  1783. .scale_factor(std::vector<double>({4.}))
  1784. .mode(torch::kLinear)
  1785. .align_corners(true)''',
  1786. input_size=(1, 2, 4),
  1787. fullname='interpolate_linear_scale_1d_align_corners',
  1788. pickle=False,
  1789. default_dtype=torch.double,
  1790. ),
  1791. dict(
  1792. constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
  1793. cpp_options_args='''F::InterpolateFuncOptions()
  1794. .size(std::vector<int64_t>({2, 2}))
  1795. .scale_factor(c10::nullopt)
  1796. .mode(torch::kNearest)''',
  1797. input_size=(1, 128, 1, 1),
  1798. fullname='interpolate_nearest_2d_launch_configs',
  1799. pickle=False,
  1800. default_dtype=torch.double,
  1801. ),
  1802. dict(
  1803. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  1804. cpp_options_args='''F::InterpolateFuncOptions()
  1805. .size(std::vector<int64_t>({12, 12}))
  1806. .scale_factor(c10::nullopt)
  1807. .mode(torch::kNearest)''',
  1808. input_size=(1, 2, 4, 4),
  1809. fullname='interpolate_nearest_2d',
  1810. pickle=False,
  1811. default_dtype=torch.double,
  1812. ),
  1813. dict(
  1814. constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
  1815. cpp_options_args='''F::InterpolateFuncOptions()
  1816. .size(std::vector<int64_t>({12, 16}))
  1817. .scale_factor(c10::nullopt)
  1818. .mode(torch::kNearest)''',
  1819. input_size=(1, 2, 3, 4),
  1820. fullname='interpolate_nearest_tuple_2d',
  1821. pickle=False,
  1822. default_dtype=torch.double,
  1823. ),
  1824. dict(
  1825. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
  1826. cpp_options_args='''F::InterpolateFuncOptions()
  1827. .size(c10::nullopt)
  1828. .scale_factor(std::vector<double>({4., 4.}))
  1829. .mode(torch::kNearest)''',
  1830. input_size=(1, 2, 4, 4),
  1831. fullname='interpolate_nearest_scale_2d',
  1832. pickle=False,
  1833. default_dtype=torch.double,
  1834. ),
  1835. dict(
  1836. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  1837. cpp_options_args='''F::InterpolateFuncOptions()
  1838. .size(std::vector<int64_t>({12, 12}))
  1839. .scale_factor(c10::nullopt)
  1840. .mode(torch::kNearest)''',
  1841. input_size=(0, 2, 4, 4),
  1842. fullname='interpolate_nearest_2d_zero_dim',
  1843. pickle=False,
  1844. ),
  1845. dict(
  1846. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
  1847. cpp_options_args='''F::InterpolateFuncOptions()
  1848. .size(std::vector<int64_t>({12, 12}))
  1849. .scale_factor(c10::nullopt)
  1850. .mode(torch::kBilinear)
  1851. .align_corners(false)''',
  1852. input_size=(1, 2, 4, 4),
  1853. fullname='interpolate_bilinear_2d',
  1854. pickle=False,
  1855. default_dtype=torch.double,
  1856. ),
  1857. dict(
  1858. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
  1859. cpp_options_args='''F::InterpolateFuncOptions()
  1860. .size(std::vector<int64_t>({12, 12}))
  1861. .scale_factor(c10::nullopt)
  1862. .mode(torch::kBilinear)
  1863. .align_corners(false)''',
  1864. input_size=(0, 2, 4, 4),
  1865. fullname='interpolate_bilinear_2d_zero_dim',
  1866. pickle=False,
  1867. ),
  1868. dict(
  1869. constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
  1870. mode='bilinear', align_corners=False),
  1871. cpp_options_args='''F::InterpolateFuncOptions()
  1872. .size(std::vector<int64_t>({4, 6}))
  1873. .scale_factor(c10::nullopt)
  1874. .mode(torch::kBilinear)
  1875. .align_corners(false)''',
  1876. input_size=(1, 2, 2, 3),
  1877. fullname='interpolate_bilinear_tuple_2d',
  1878. pickle=False,
  1879. default_dtype=torch.double,
  1880. ),
  1881. dict(
  1882. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
  1883. mode='bilinear', align_corners=False),
  1884. cpp_options_args='''F::InterpolateFuncOptions()
  1885. .size(c10::nullopt)
  1886. .scale_factor(std::vector<double>({4., 4.}))
  1887. .mode(torch::kBilinear)
  1888. .align_corners(false)''',
  1889. input_size=(1, 2, 4, 4),
  1890. fullname='interpolate_bilinear_scale_2d',
  1891. pickle=False,
  1892. default_dtype=torch.double,
  1893. ),
  1894. dict(
  1895. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
  1896. mode='bilinear', align_corners=False),
  1897. cpp_options_args='''F::InterpolateFuncOptions()
  1898. .size(c10::nullopt)
  1899. .scale_factor(std::vector<double>({2., 2.}))
  1900. .mode(torch::kBilinear)
  1901. .align_corners(false)''',
  1902. input_size=(1, 2, 4, 4),
  1903. fullname='interpolate_bilinear_scale_tuple_shared_2d',
  1904. pickle=False,
  1905. default_dtype=torch.double,
  1906. ),
  1907. dict(
  1908. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
  1909. mode='bilinear', align_corners=False),
  1910. cpp_options_args='''F::InterpolateFuncOptions()
  1911. .size(c10::nullopt)
  1912. .scale_factor(std::vector<double>({2., 1.}))
  1913. .mode(torch::kBilinear)
  1914. .align_corners(false)''',
  1915. input_size=(1, 2, 4, 4),
  1916. fullname='interpolate_bilinear_scale_tuple_skewed_2d',
  1917. pickle=False,
  1918. default_dtype=torch.double,
  1919. ),
  1920. dict(
  1921. constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
  1922. cpp_options_args='''F::InterpolateFuncOptions()
  1923. .size(std::vector<int64_t>({4, 6}))
  1924. .scale_factor(c10::nullopt)
  1925. .mode(torch::kBilinear)
  1926. .align_corners(true)''',
  1927. input_size=(1, 2, 4, 4),
  1928. fullname='interpolate_bilinear_tuple_2d_align_corners',
  1929. pickle=False,
  1930. default_dtype=torch.double,
  1931. ),
  1932. dict(
  1933. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
  1934. mode='bilinear', align_corners=True),
  1935. cpp_options_args='''F::InterpolateFuncOptions()
  1936. .size(c10::nullopt)
  1937. .scale_factor(std::vector<double>({2., 1.}))
  1938. .mode(torch::kBilinear)
  1939. .align_corners(true)''',
  1940. input_size=(1, 2, 4, 4),
  1941. fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
  1942. pickle=False,
  1943. default_dtype=torch.double,
  1944. ),
  1945. dict(
  1946. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
  1947. cpp_options_args='''F::InterpolateFuncOptions()
  1948. .size(std::vector<int64_t>({12, 12}))
  1949. .scale_factor(c10::nullopt)
  1950. .mode(torch::kBicubic)
  1951. .align_corners(false)''',
  1952. input_size=(1, 2, 4, 4),
  1953. fullname='interpolate_bicubic_2d',
  1954. pickle=False,
  1955. default_dtype=torch.double,
  1956. ),
  1957. dict(
  1958. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
  1959. cpp_options_args='''F::InterpolateFuncOptions()
  1960. .size(std::vector<int64_t>({12, 12}))
  1961. .scale_factor(c10::nullopt)
  1962. .mode(torch::kBicubic)
  1963. .align_corners(false)''',
  1964. input_size=(0, 2, 4, 4),
  1965. fullname='interpolate_bicubic_2d_zero_dim',
  1966. pickle=False,
  1967. ),
  1968. dict(
  1969. constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
  1970. mode='bicubic', align_corners=False),
  1971. cpp_options_args='''F::InterpolateFuncOptions()
  1972. .size(std::vector<int64_t>({4, 6}))
  1973. .scale_factor(c10::nullopt)
  1974. .mode(torch::kBicubic)
  1975. .align_corners(false)''',
  1976. input_size=(1, 2, 2, 3),
  1977. fullname='interpolate_bicubic_tuple_2d',
  1978. pickle=False,
  1979. default_dtype=torch.double,
  1980. ),
  1981. dict(
  1982. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
  1983. cpp_options_args='''F::InterpolateFuncOptions()
  1984. .size(c10::nullopt)
  1985. .scale_factor(std::vector<double>({4., 4.}))
  1986. .mode(torch::kBicubic)
  1987. .align_corners(false)''',
  1988. input_size=(1, 2, 4, 4),
  1989. fullname='interpolate_bicubic_scale_2d',
  1990. pickle=False,
  1991. default_dtype=torch.double,
  1992. ),
  1993. dict(
  1994. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
  1995. mode='bicubic', align_corners=False),
  1996. cpp_options_args='''F::InterpolateFuncOptions()
  1997. .size(c10::nullopt)
  1998. .scale_factor(std::vector<double>({2., 2.}))
  1999. .mode(torch::kBicubic)
  2000. .align_corners(false)''',
  2001. input_size=(1, 2, 4, 4),
  2002. fullname='interpolate_bicubic_scale_tuple_shared_2d',
  2003. pickle=False,
  2004. default_dtype=torch.double,
  2005. ),
  2006. dict(
  2007. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
  2008. mode='bicubic', align_corners=False),
  2009. cpp_options_args='''F::InterpolateFuncOptions()
  2010. .size(c10::nullopt)
  2011. .scale_factor(std::vector<double>({2., 1.}))
  2012. .mode(torch::kBicubic)
  2013. .align_corners(false)''',
  2014. input_size=(1, 2, 4, 4),
  2015. fullname='interpolate_bicubic_scale_tuple_skewed_2d',
  2016. pickle=False,
  2017. default_dtype=torch.double,
  2018. ),
  2019. dict(
  2020. constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
  2021. cpp_options_args='''F::InterpolateFuncOptions()
  2022. .size(std::vector<int64_t>({4, 6}))
  2023. .scale_factor(c10::nullopt)
  2024. .mode(torch::kBicubic)
  2025. .align_corners(true)''',
  2026. input_size=(1, 2, 4, 4),
  2027. fullname='interpolate_bicubic_tuple_2d_align_corners',
  2028. pickle=False,
  2029. default_dtype=torch.double,
  2030. ),
  2031. dict(
  2032. constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
  2033. mode='bicubic', align_corners=True),
  2034. cpp_options_args='''F::InterpolateFuncOptions()
  2035. .size(c10::nullopt)
  2036. .scale_factor(std::vector<double>({2., 1.}))
  2037. .mode(torch::kBicubic)
  2038. .align_corners(true)''',
  2039. input_size=(1, 2, 4, 4),
  2040. fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
  2041. pickle=False,
  2042. default_dtype=torch.double,
  2043. ),
  2044. dict(
  2045. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  2046. cpp_options_args='''F::InterpolateFuncOptions()
  2047. .size(std::vector<int64_t>({12, 12, 12}))
  2048. .scale_factor(c10::nullopt)
  2049. .mode(torch::kNearest)''',
  2050. input_size=(1, 2, 4, 4, 4),
  2051. fullname='interpolate_nearest_3d',
  2052. pickle=False,
  2053. default_dtype=torch.double,
  2054. ),
  2055. dict(
  2056. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
  2057. cpp_options_args='''F::InterpolateFuncOptions()
  2058. .size(std::vector<int64_t>({12, 12, 12}))
  2059. .scale_factor(c10::nullopt)
  2060. .mode(torch::kNearest)''',
  2061. input_size=(0, 2, 4, 4, 4),
  2062. fullname='interpolate_nearest_3d_zero_dim',
  2063. pickle=False,
  2064. ),
  2065. dict(
  2066. constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
  2067. cpp_options_args='''F::InterpolateFuncOptions()
  2068. .size(std::vector<int64_t>({12, 16, 16}))
  2069. .scale_factor(c10::nullopt)
  2070. .mode(torch::kNearest)''',
  2071. input_size=(1, 2, 3, 4, 4),
  2072. fullname='interpolate_nearest_tuple_3d',
  2073. pickle=False,
  2074. default_dtype=torch.double,
  2075. ),
  2076. dict(
  2077. constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
  2078. cpp_options_args='''F::InterpolateFuncOptions()
  2079. .size(c10::nullopt)
  2080. .scale_factor(std::vector<double>({4., 4., 4.}))
  2081. .mode(torch::kNearest)''',
  2082. input_size=(1, 2, 4, 4, 4),
  2083. fullname='interpolate_nearest_scale_3d',
  2084. pickle=False,
  2085. default_dtype=torch.double,
  2086. ),
  2087. dict(
  2088. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
  2089. cpp_options_args='''F::InterpolateFuncOptions()
  2090. .size(std::vector<int64_t>({12, 12, 12}))
  2091. .scale_factor(c10::nullopt)
  2092. .mode(torch::kTrilinear)
  2093. .align_corners(false)''',
  2094. input_size=(1, 2, 4, 4, 4),
  2095. fullname='interpolate_trilinear_3d',
  2096. pickle=False,
  2097. default_dtype=torch.double,
  2098. ),
  2099. dict(
  2100. constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
  2101. cpp_options_args='''F::InterpolateFuncOptions()
  2102. .size(std::vector<int64_t>({12, 12, 12}))
  2103. .scale_factor(c10::nullopt)
  2104. .mode(torch::kTrilinear)
  2105. .align_corners(false)''',
  2106. input_size=(0, 2, 4, 4, 4),
  2107. fullname='interpolate_trilinear_3d_zero_dim',
  2108. pickle=False,
  2109. ),
  2110. dict(
  2111. constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
  2112. scale_factor=None, mode='trilinear', align_corners=False),
  2113. cpp_options_args='''F::InterpolateFuncOptions()
  2114. .size(std::vector<int64_t>({4, 6, 6}))
  2115. .scale_factor(c10::nullopt)
  2116. .mode(torch::kTrilinear)
  2117. .align_corners(false)''',
  2118. input_size=(1, 2, 2, 3, 3),
  2119. fullname='interpolate_trilinear_tuple_3d',
  2120. pickle=False,
  2121. default_dtype=torch.double,
  2122. ),
  2123. dict(
  2124. constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
  2125. cpp_options_args='''F::InterpolateFuncOptions()
  2126. .size(c10::nullopt)
  2127. .scale_factor(std::vector<double>({3., 3., 3.}))
  2128. .mode(torch::kTrilinear)
  2129. .align_corners(false)''',
  2130. input_size=(1, 2, 3, 4, 5),
  2131. fullname='interpolate_trilinear_scale_3d',
  2132. # See https://github.com/pytorch/pytorch/issues/5006
  2133. precision=3e-4,
  2134. pickle=False,
  2135. default_dtype=torch.double,
  2136. ),
  2137. dict(
  2138. constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
  2139. mode='trilinear', align_corners=True),
  2140. cpp_options_args='''F::InterpolateFuncOptions()
  2141. .size(std::vector<int64_t>({4, 6, 6}))
  2142. .scale_factor(c10::nullopt)
  2143. .mode(torch::kTrilinear)
  2144. .align_corners(true)''',
  2145. input_size=(1, 2, 2, 3, 3),
  2146. fullname='interpolate_trilinear_tuple_3d_align_corners',
  2147. pickle=False,
  2148. default_dtype=torch.double
  2149. ),
  2150. dict(
  2151. constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
  2152. cpp_options_args='''F::InterpolateFuncOptions()
  2153. .size(c10::nullopt)
  2154. .scale_factor(std::vector<double>({3., 3., 3.}))
  2155. .mode(torch::kTrilinear)
  2156. .align_corners(true)''',
  2157. input_size=(1, 2, 3, 4, 4),
  2158. fullname='interpolate_trilinear_scale_3d_align_corners',
  2159. # See https://github.com/pytorch/pytorch/issues/5006
  2160. precision=3e-4,
  2161. pickle=False,
  2162. default_dtype=torch.double,
  2163. ),
  2164. dict(
  2165. constructor=wrap_functional(F.softmax, dim=-1),
  2166. cpp_options_args='F::SoftmaxFuncOptions(-1)',
  2167. input_size=(2, 128), # trigger the last-dim algo in CUDA
  2168. fullname='softmax_lastdim',
  2169. pickle=False,
  2170. default_dtype=torch.double,
  2171. ),
  2172. dict(
  2173. constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
  2174. cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
  2175. input_size=(2, 128),
  2176. fullname='softmax_lastdim_dtype',
  2177. pickle=False,
  2178. test_cuda=False,
  2179. default_dtype=torch.double,
  2180. ),
  2181. dict(
  2182. constructor=wrap_functional(F.softmax, dim=1),
  2183. cpp_options_args='F::SoftmaxFuncOptions(1)',
  2184. input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
  2185. fullname='softmax_spatial_special',
  2186. pickle=False,
  2187. default_dtype=torch.double,
  2188. ),
  2189. dict(
  2190. constructor=wrap_functional(F.softmax, dim=1),
  2191. cpp_options_args='F::SoftmaxFuncOptions(1)',
  2192. input_size=(2, 2, 4, 4), # regular spatial algorithm
  2193. fullname='softmax_spatial',
  2194. pickle=False,
  2195. default_dtype=torch.double,
  2196. ),
  2197. dict(
  2198. constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
  2199. cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
  2200. input_size=(2, 2, 4, 4), # regular spatial algorithm
  2201. fullname='softmax_spatial_dtype',
  2202. pickle=False,
  2203. test_cuda=False,
  2204. default_dtype=torch.double,
  2205. ),
  2206. dict(
  2207. constructor=wrap_functional(F.softmax, dim=0),
  2208. cpp_options_args='F::SoftmaxFuncOptions(0)',
  2209. input_size=(2, 3, 4, 5),
  2210. fullname='softmax_functional_dim0',
  2211. test_cuda=False,
  2212. pickle=False,
  2213. default_dtype=torch.double,
  2214. ),
  2215. dict(
  2216. constructor=wrap_functional(F.softmax, dim=3),
  2217. cpp_options_args='F::SoftmaxFuncOptions(3)',
  2218. input_size=(2, 3, 4, 5),
  2219. fullname='softmax_functional_dim3',
  2220. test_cuda=False,
  2221. pickle=False,
  2222. default_dtype=torch.double,
  2223. ),
  2224. dict(
  2225. constructor=wrap_functional(F.softmax, dim=-1),
  2226. cpp_options_args='F::SoftmaxFuncOptions(-1)',
  2227. input_size=(),
  2228. fullname='softmax_functional_scalar',
  2229. test_cuda=False,
  2230. pickle=False,
  2231. ),
  2232. dict(
  2233. constructor=wrap_functional(F.log_softmax, dim=-1),
  2234. cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
  2235. input_size=(2, 128), # trigger the last-dim algo in CUDA
  2236. fullname='log_softmax_lastdim',
  2237. pickle=False,
  2238. default_dtype=torch.double,
  2239. ),
  2240. dict(
  2241. constructor=wrap_functional(F.log_softmax, dim=1),
  2242. cpp_options_args='F::LogSoftmaxFuncOptions(1)',
  2243. input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
  2244. fullname='log_softmax_spatial_special',
  2245. pickle=False,
  2246. default_dtype=torch.double,
  2247. ),
  2248. dict(
  2249. constructor=wrap_functional(F.log_softmax, dim=1),
  2250. cpp_options_args='F::LogSoftmaxFuncOptions(1)',
  2251. input_size=(2, 2, 4, 4), # regular spatial algorithm
  2252. fullname='log_softmax_spatial',
  2253. pickle=False,
  2254. default_dtype=torch.double,
  2255. ),
  2256. dict(
  2257. constructor=wrap_functional(F.log_softmax, dim=0),
  2258. cpp_options_args='F::LogSoftmaxFuncOptions(0)',
  2259. input_size=(2, 3, 4, 5),
  2260. fullname='log_softmax_dim0',
  2261. pickle=False,
  2262. default_dtype=torch.double,
  2263. ),
  2264. dict(
  2265. constructor=wrap_functional(F.log_softmax, dim=3),
  2266. cpp_options_args='F::LogSoftmaxFuncOptions(3)',
  2267. input_size=(2, 3, 4, 5),
  2268. fullname='log_softmax_dim3',
  2269. pickle=False,
  2270. default_dtype=torch.double,
  2271. ),
  2272. dict(
  2273. constructor=wrap_functional(F.log_softmax, dim=0),
  2274. cpp_options_args='F::LogSoftmaxFuncOptions(0)',
  2275. input_size=(),
  2276. fullname='log_softmax_scalar',
  2277. pickle=False,
  2278. ),
  2279. dict(
  2280. fullname='Unfold',
  2281. constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
  2282. cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
  2283. input_size=(2, 4, 3, 3),
  2284. check_gradgrad=False,
  2285. test_cuda=True,
  2286. default_dtype=torch.double,
  2287. ),
  2288. dict(
  2289. fullname='Fold',
  2290. constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
  2291. cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
  2292. input_size=(2, 16, 4),
  2293. check_gradgrad=False,
  2294. test_cuda=True,
  2295. default_dtype=torch.double,
  2296. ),
  2297. dict(
  2298. fullname='Fold_no_batch_dim_input',
  2299. constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
  2300. cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
  2301. input_size=(16, 4),
  2302. check_gradgrad=False,
  2303. ref=single_batch_reference_fn,
  2304. test_cuda=True,
  2305. default_dtype=torch.double,
  2306. ),
  2307. dict(
  2308. fullname='Unfold_int_input',
  2309. constructor=lambda: nn.Unfold(2, 1, 0, 1),
  2310. cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
  2311. input_size=(2, 4, 3, 3),
  2312. check_gradgrad=False,
  2313. test_cuda=True,
  2314. default_dtype=torch.double,
  2315. ),
  2316. dict(
  2317. fullname='Fold_int_input',
  2318. constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
  2319. cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
  2320. input_size=(2, 16, 4),
  2321. check_gradgrad=False,
  2322. test_cuda=True,
  2323. default_dtype=torch.double,
  2324. ),
  2325. dict(
  2326. fullname='Fold_no_batch_dim_int_input',
  2327. constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
  2328. cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
  2329. input_size=(16, 4),
  2330. ref=single_batch_reference_fn,
  2331. check_gradgrad=False,
  2332. test_cuda=True,
  2333. default_dtype=torch.double,
  2334. ),
  2335. dict(
  2336. module_name='RReLU',
  2337. constructor_args=(0.1, 0.9),
  2338. cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
  2339. input_size=(),
  2340. desc='with_up_down_scalar',
  2341. test_cuda=False,
  2342. default_dtype=torch.double,
  2343. ),
  2344. dict(
  2345. module_name='PairwiseDistance',
  2346. input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
  2347. default_dtype=torch.double,
  2348. ),
  2349. dict(
  2350. module_name='PairwiseDistance',
  2351. input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
  2352. desc='broadcast_lhs',
  2353. default_dtype=torch.double,
  2354. ),
  2355. dict(
  2356. module_name='PairwiseDistance',
  2357. input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
  2358. desc='broadcast_rhs',
  2359. default_dtype=torch.double,
  2360. ),
  2361. dict(
  2362. module_name='PairwiseDistance',
  2363. constructor_args=(1.5, 1e-05, True),
  2364. cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
  2365. input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
  2366. desc='with_non_default_args',
  2367. default_dtype=torch.double,
  2368. ),
  2369. dict(
  2370. module_name='PairwiseDistance',
  2371. input_fn=lambda: (torch.randn(8), torch.randn(8)),
  2372. reference_fn=single_batch_reference_fn,
  2373. desc='no_batch_dim',
  2374. default_dtype=torch.double,
  2375. ),
  2376. dict(
  2377. module_name='TransformerEncoderLayer',
  2378. constructor_args=(4, 2, 16, 0.0),
  2379. cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
  2380. .dim_feedforward(16)
  2381. .dropout(0.0)''',
  2382. input_size=(2, 3, 4),
  2383. desc='relu_activation',
  2384. with_tf32=True,
  2385. tf32_precision=0.1,
  2386. # TODO(#50743): figure out the error
  2387. # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
  2388. # at non-singleton dimension 2
  2389. check_batched_grad=False,
  2390. check_gradgrad=False,
  2391. default_dtype=torch.double,
  2392. ),
  2393. dict(
  2394. module_name='TransformerEncoderLayer',
  2395. constructor_args=(4, 2, 8, 0.0, F.gelu),
  2396. cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
  2397. .dim_feedforward(8)
  2398. .dropout(0.0)
  2399. .activation(torch::kGELU)''',
  2400. input_size=(2, 3, 4),
  2401. check_gradgrad=False,
  2402. desc='gelu_activation',
  2403. with_tf32=True,
  2404. tf32_precision=0.08 if SM90OrLater else 0.05,
  2405. default_dtype=torch.double,
  2406. ),
  2407. dict(
  2408. module_name='TransformerDecoderLayer',
  2409. constructor_args=(4, 2, 8, 0.0),
  2410. cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
  2411. .dim_feedforward(8)
  2412. .dropout(0.0)''',
  2413. input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
  2414. check_gradgrad=False,
  2415. desc='relu_activation',
  2416. with_tf32=True,
  2417. tf32_precision=0.05,
  2418. default_dtype=torch.double,
  2419. ),
  2420. dict(
  2421. module_name='TransformerDecoderLayer',
  2422. constructor_args=(4, 2, 8, 0.0, F.gelu),
  2423. cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
  2424. .dim_feedforward(8)
  2425. .dropout(0.0)
  2426. .activation(torch::kGELU)''',
  2427. input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
  2428. check_gradgrad=False,
  2429. desc='gelu_activation',
  2430. with_tf32=True,
  2431. tf32_precision=0.05,
  2432. default_dtype=torch.double,
  2433. ),
  2434. dict(
  2435. module_name='Transformer',
  2436. constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
  2437. cpp_constructor_args='''torch::nn::TransformerOptions()
  2438. .d_model(4)
  2439. .nhead(2)
  2440. .num_encoder_layers(2)
  2441. .num_decoder_layers(2)
  2442. .dim_feedforward(8)
  2443. .dropout(0.0)
  2444. .activation(torch::kReLU)''',
  2445. input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
  2446. check_gradgrad=False,
  2447. desc='multilayer_coder',
  2448. with_tf32=True,
  2449. tf32_precision=0.05 if SM90OrLater else 0.03,
  2450. default_dtype=torch.double,
  2451. ),
  2452. dict(
  2453. module_name='Linear',
  2454. constructor_args=(3, 5),
  2455. cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
  2456. input_fn=lambda: torch.rand(3),
  2457. reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
  2458. desc="no_batch_dim",
  2459. with_tf32=True,
  2460. tf32_precision=0.005,
  2461. default_dtype=torch.double,
  2462. ),
  2463. dict(
  2464. module_name='Flatten',
  2465. cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
  2466. constructor_args=(-3, -1),
  2467. input_size=(3, 4, 5),
  2468. reference_fn=single_batch_reference_fn,
  2469. desc="no_batch_dim",
  2470. default_dtype=torch.double,
  2471. ),
  2472. dict(
  2473. module_name='Unflatten',
  2474. cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
  2475. constructor_args=(-2, torch.Size([2, 2])),
  2476. input_size=(3, 4, 5),
  2477. reference_fn=single_batch_reference_fn,
  2478. desc="no_batch_dim",
  2479. default_dtype=torch.double,
  2480. ),
  2481. dict(
  2482. module_name='LayerNorm',
  2483. constructor_args=([56, 56, 56], 1e-5, False),
  2484. cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
  2485. input_size=(4, 56, 56, 56),
  2486. cudnn=True,
  2487. check_eval=True,
  2488. gradcheck_fast_mode=True,
  2489. check_half=True,
  2490. desc='3d_no_affine_large_feature',
  2491. ),
  2492. ]
  2493. # add conv padding mode tests:
  2494. for padding_mode, cpp_padding_mode in zip(
  2495. ['reflect', 'circular', 'replicate', 'zeros'],
  2496. ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
  2497. # conv signature:
  2498. # in_channels, out_channels, kernel_size, stride=1,
  2499. # padding=0, dilation=1, groups=1,
  2500. # bias=True, padding_mode='zeros'
  2501. for d in (1, 2, 3):
  2502. if d == 3 and padding_mode == 'reflect':
  2503. # FIXME: remove after implementing reflection pad 3d
  2504. # https://github.com/pytorch/pytorch/issues/27655
  2505. continue
  2506. padding = tuple(range(1, d + 1))
  2507. cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
  2508. input_size = (2, 2) + (4,) * d
2509. output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1`; see the sketch after this loop
  2510. new_module_tests.append(
  2511. dict(
  2512. module_name=f'Conv{d}d',
  2513. constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
  2514. cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
  2515. .stride(2)
  2516. .padding({cpp_padding})
  2517. .dilation(1)
  2518. .groups(1)
  2519. .bias(true)
  2520. .padding_mode({cpp_padding_mode})''',
  2521. input_size=input_size,
  2522. output_size=output_size,
  2523. cudnn=True,
  2524. desc=f'{padding_mode}_stride2_pad2',
  2525. with_tf32=True,
  2526. tf32_precision=0.05,
  2527. default_dtype=torch.double,
  2528. ),
  2529. )
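# Illustrative sketch only (the helper name below is hypothetical and nothing in
# this file calls it): the `p + 1` output size used above follows from the
# generic conv formula `(in + 2*p - dilation*(k - 1) - 1) // stride + 1`
# with in=4, k=3, stride=2, dilation=1.
def _example_conv_output_size_formula():
    for p in (1, 2, 3):
        generic = (4 + 2 * p - 1 * (3 - 1) - 1) // 2 + 1
        assert generic == (4 + 2 * p - 3) // 2 + 1 == p + 1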
2530. # Check that non-linear activations work with no batch dimensions
  2531. non_linear_activations_no_batch = [
  2532. 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
  2533. 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
  2534. 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
  2535. 'Tanhshrink', 'Threshold'
  2536. ]
  2537. non_linear_activations_extra_info: Dict[str, dict] = {
  2538. 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
  2539. 'Threshold': {'constructor_args': (2., 1.)},
  2540. 'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
  2541. 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
2542. # For RReLU, tests that compare CPU and GPU results fail because the RNG
2543. # behaves differently on CPU and GPU
  2544. 'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
  2545. 'ELU': {'default_dtype': torch.double},
  2546. 'GELU': {'default_dtype': torch.double},
  2547. 'GLU': {'default_dtype': torch.double},
  2548. 'Hardshrink': {'default_dtype': torch.double},
  2549. 'Hardtanh': {'default_dtype': torch.double},
  2550. 'LeakyReLU': {'default_dtype': torch.double},
  2551. 'LogSigmoid': {'default_dtype': torch.double},
  2552. 'Mish': {'default_dtype': torch.double},
  2553. 'PReLU': {'default_dtype': torch.double},
  2554. 'ReLU6': {'default_dtype': torch.double},
  2555. 'ReLU': {'default_dtype': torch.double},
  2556. 'SELU': {'default_dtype': torch.double},
  2557. 'SiLU': {'default_dtype': torch.double},
  2558. 'Sigmoid': {'default_dtype': torch.double},
  2559. 'Softplus': {'default_dtype': torch.double},
  2560. 'Softshrink': {'default_dtype': torch.double},
  2561. 'Softsign': {'default_dtype': torch.double},
  2562. 'Tanh': {'default_dtype': torch.double},
  2563. 'Tanhshrink': {'default_dtype': torch.double},
  2564. }
  2565. for non_linear_activation in non_linear_activations_no_batch:
  2566. activation_test_info = dict(
  2567. module_name=non_linear_activation,
  2568. input_size=(4,),
  2569. reference_fn=single_batch_reference_fn,
  2570. desc='no_batch_dim',
  2571. test_cpp_api_parity=False,
  2572. )
  2573. extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
  2574. activation_test_info.update(extra_info)
  2575. new_module_tests.append(activation_test_info)
  2576. def kldivloss_reference(input, target, reduction='mean', log_target=False):
  2577. if log_target:
  2578. result = torch.exp(target) * (target - input)
  2579. else:
  2580. result = target * (target.log() - input)
  2581. if reduction == 'mean':
  2582. return result.mean()
  2583. elif reduction == 'sum':
  2584. return result.sum()
  2585. elif reduction == 'batchmean' and result.dim() != 0:
  2586. return result.sum() / result.size(0)
  2587. return result
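# A minimal, illustrative cross-check of the reference above against F.kl_div for
# the default (non-log-target) case; the helper name is hypothetical and the
# tensor values are arbitrary.
def _example_check_kldivloss_reference():
    inp = F.log_softmax(torch.randn(3, 5, dtype=torch.double), dim=1)
    tgt = F.softmax(torch.randn(3, 5, dtype=torch.double), dim=1)
    assert torch.allclose(kldivloss_reference(inp, tgt, reduction='mean'),
                          F.kl_div(inp, tgt, reduction='mean'))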
  2588. def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
  2589. reduction='mean'):
  2590. assert input.dim() >= 3
  2591. N = input.size(0)
  2592. C = input.size(1)
  2593. out_size = (N,) + input.size()[2:]
  2594. output = torch.zeros(out_size).type_as(input)
  2595. if weight is None:
  2596. weight = torch.ones(C).type_as(input)
  2597. total_weight = 0
  2598. for tup in product(*[range(size) for size in out_size]):
  2599. t_nx = target[tup]
  2600. norm = 0. if ignore_index == t_nx else weight[t_nx].item()
  2601. input_index = list(tup)
  2602. input_index.insert(1, t_nx)
  2603. output[tup] = -input[tuple(input_index)] * norm
  2604. total_weight += norm
  2605. if reduction == 'mean':
  2606. return output.sum() / total_weight
  2607. elif reduction == 'sum':
  2608. return output.sum()
  2609. return output
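# Illustrative cross-check (hypothetical helper): with no weight and no ignored
# indices, total_weight equals the number of spatial positions, so the 'mean'
# reduction above is an ordinary mean and should agree with F.nll_loss.
def _example_check_nlllossNd_reference():
    inp = F.log_softmax(torch.randn(2, 4, 3, 3, dtype=torch.double), dim=1)
    tgt = torch.randint(0, 4, (2, 3, 3))
    assert torch.allclose(nlllossNd_reference(inp, tgt, reduction='mean'),
                          F.nll_loss(inp, tgt, reduction='mean'))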
  2610. def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
  2611. label_smoothing=0.0):
  2612. assert input.dim() >= 2
  2613. input = torch.log_softmax(input, 1)
  2614. C = input.size(1)
  2615. if weight is None:
  2616. weight = torch.ones(C).type_as(input)
  2617. weight = weight.view(1, C, *(1 for _ in input.shape[2:]))
  2618. if label_smoothing > 0.0:
  2619. assert label_smoothing <= 1.0
  2620. target = (target * (1 - label_smoothing) + label_smoothing / C)
  2621. output = -(input * target * weight).sum(dim=1)
  2622. if reduction == 'mean':
  2623. return output.mean()
  2624. elif reduction == 'sum':
  2625. return output.sum()
  2626. return output
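# Illustrative cross-check (hypothetical helper): for class-probability targets
# the reference reduces to -(log_softmax(input) * target).sum(dim=1), which is
# what F.cross_entropy computes when target has the same shape as input.
def _example_check_cross_entropy_prob_targets():
    inp = torch.randn(4, 6, dtype=torch.double)
    tgt = F.softmax(torch.randn(4, 6, dtype=torch.double), dim=1)
    assert torch.allclose(cross_entropy_loss_prob_target_reference(inp, tgt),
                          F.cross_entropy(inp, tgt))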
  2627. def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
  2628. reduction='mean', label_smoothing=0.0):
  2629. log_softmax_input = torch.log_softmax(input, 1)
  2630. nllloss = F.nll_loss(
  2631. log_softmax_input,
  2632. target,
  2633. weight,
  2634. ignore_index=ignore_index,
  2635. reduction=reduction)
  2636. if label_smoothing == 0.0:
  2637. return nllloss
  2638. assert 0.0 < label_smoothing <= 1.0
  2639. input = torch.log_softmax(input, 1)
  2640. C = input.size(1)
  2641. if weight is not None:
  2642. input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))
  2643. smooth_loss = -torch.sum(input, 1)
  2644. ignore_mask = target == ignore_index
  2645. smooth_loss.masked_fill_(ignore_mask, 0.0)
  2646. if reduction == 'mean':
  2647. if weight is not None:
2648. # TODO: This code path can be removed if #61309 is resolved
  2649. # loss is normalized by the weights to be consistent with nll_loss_nd
  2650. ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
  2651. else:
  2652. ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
  2653. elif reduction == 'sum':
  2654. ret = torch.sum(smooth_loss)
  2655. else:
  2656. ret = smooth_loss
  2657. return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
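# Illustrative cross-check (hypothetical helper): with smoothing epsilon the
# per-sample loss is (1 - eps) * nll + (eps / C) * sum_c(-log p_c), which is how
# F.cross_entropy implements label_smoothing for index targets.
def _example_check_cross_entropy_label_smoothing():
    inp = torch.randn(5, 7, dtype=torch.double)
    tgt = torch.randint(0, 7, (5,))
    assert torch.allclose(
        cross_entropy_loss_indices_target_reference(inp, tgt, label_smoothing=0.1),
        F.cross_entropy(inp, tgt, label_smoothing=0.1))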
  2658. def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
  2659. label_smoothing=0.0):
  2660. if input.shape == target.shape:
  2661. return cross_entropy_loss_prob_target_reference(
  2662. input,
  2663. target,
  2664. weight=weight,
  2665. reduction=reduction,
  2666. label_smoothing=label_smoothing)
  2667. else:
  2668. return cross_entropy_loss_indices_target_reference(
  2669. input, target, weight=weight, reduction=reduction,
  2670. ignore_index=ignore_index, label_smoothing=label_smoothing
  2671. )
  2672. def nllloss_reference(input, target, weight=None, ignore_index=-100,
  2673. reduction='mean'):
  2674. def nll_loss_helper(input, target, weight, ignore_index):
  2675. if target == ignore_index:
  2676. return (0, 0)
  2677. norm = 1 if weight is None else weight[target]
  2678. result = -input[target] * norm
  2679. return (result, norm)
  2680. losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
  2681. for i, t in zip(input, target)]
  2682. losses, weights = zip(*losses_and_weights)
  2683. losses_tensor = input.new_tensor(losses)
  2684. if reduction == 'mean':
  2685. return sum(losses_tensor) / sum(weights)
  2686. elif reduction == 'sum':
  2687. return sum(losses_tensor)
  2688. else:
  2689. return losses_tensor
  2690. def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
  2691. abs_diff = (input - target).abs()
  2692. ge_beta_mask = (abs_diff >= beta).type_as(abs_diff)
  2693. lt_beta_mask = (abs_diff < beta).type_as(abs_diff)
2694. # when beta == 0 we should just use l1_loss
  2695. if beta == 0:
  2696. output = abs_diff
  2697. else:
  2698. output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta
  2699. if reduction == 'mean':
  2700. return output.mean()
  2701. elif reduction == 'sum':
  2702. return output.sum()
  2703. return output
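# Illustrative cross-check (hypothetical helper): below beta the loss is
# quadratic, 0.5 * diff**2 / beta, and above beta it is linear, |diff| - 0.5 * beta,
# so it should match F.smooth_l1_loss for the same beta.
def _example_check_smoothl1loss_reference():
    inp = torch.randn(10, dtype=torch.double)
    tgt = torch.randn(10, dtype=torch.double)
    assert torch.allclose(smoothl1loss_reference(inp, tgt, beta=0.5),
                          F.smooth_l1_loss(inp, tgt, beta=0.5))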
  2704. def huberloss_reference(input, target, reduction='mean', delta=1.0):
  2705. abs_diff = (input - target).abs()
  2706. ge_delta_mask = (abs_diff >= delta)
  2707. lt_delta_mask = (abs_diff < delta)
  2708. output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2)
  2709. if reduction == 'mean':
  2710. return output.mean()
  2711. elif reduction == 'sum':
  2712. return output.sum()
  2713. return output
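# Illustrative cross-check (hypothetical helper): Huber is the un-normalized
# sibling of smooth L1 (quadratic below delta, linear with slope delta above it)
# and should match F.huber_loss for the same delta.
def _example_check_huberloss_reference():
    inp = torch.randn(10, dtype=torch.double)
    tgt = torch.randn(10, dtype=torch.double)
    assert torch.allclose(huberloss_reference(inp, tgt, delta=1.5),
                          F.huber_loss(inp, tgt, delta=1.5))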
  2714. def _multilabelmarginloss_reference(input, target):
  2715. targets = []
  2716. for target_index in target:
  2717. if target_index < 0:
  2718. break
  2719. targets.append(target_index)
  2720. sum = 0
  2721. for target_index in targets:
  2722. for i in range(0, len(input)):
  2723. if i not in targets:
  2724. sum += max(0, 1 - input[target_index] + input[i])
  2725. return sum
  2726. def multilabelmarginloss_reference(input, target, reduction='mean'):
  2727. # make everything 2-dimensional
  2728. input_dim = input.dim()
  2729. if input.dim() < 2:
  2730. assert target.dim() < 2
  2731. input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
  2732. target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0)
  2733. n = input.size(0)
  2734. dim = input.size(1)
  2735. output = input.new(n).zero_()
  2736. for i in range(0, n):
  2737. output[i] = _multilabelmarginloss_reference(input[i], target[i])
  2738. if reduction == 'mean':
  2739. return output.mean() / dim
  2740. elif reduction == 'sum':
  2741. return output.sum() / dim
  2742. elif input_dim < 2:
  2743. # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
  2744. # back to correct dimensionality
  2745. return output.squeeze() / dim
  2746. else:
  2747. return output / dim
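# Illustrative cross-check (hypothetical helper): the target lists class indices
# and is terminated by -1; the reference should match F.multilabel_margin_loss,
# including the division by the number of classes.
def _example_check_multilabelmarginloss_reference():
    inp = torch.randn(4, dtype=torch.double)
    tgt = torch.tensor([3, 0, -1, 1])
    assert torch.allclose(multilabelmarginloss_reference(inp, tgt),
                          F.multilabel_margin_loss(inp, tgt))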
  2748. def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
  2749. margin_clamp = (margin - input).clamp(min=0).type_as(input)
  2750. output = torch.where(target == 1, input, margin_clamp)
  2751. if reduction == 'mean':
  2752. return output.mean()
  2753. elif reduction == 'sum':
  2754. return output.sum()
  2755. return output
  2756. def softmarginloss_reference(input, target, reduction='mean'):
  2757. output = (1 + (-input * target).exp()).log()
  2758. if reduction == 'mean':
  2759. return output.mean()
  2760. elif reduction == 'sum':
  2761. return output.sum()
  2762. return output
  2763. def _multimarginloss_reference(input, target_idx, p, margin, weight):
  2764. if weight is None:
  2765. weight = input.new(len(input)).fill_(1)
  2766. output = 0
  2767. for i in range(0, len(input)):
  2768. if i != target_idx:
  2769. output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p)
  2770. return output
  2771. def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
  2772. if input.dim() < 2:
  2773. input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
  2774. target_dim = target.dim()
  2775. if target.dim() == 0:
  2776. target = target.unsqueeze(0)
  2777. n = input.size(0)
  2778. dim = input.size(1)
  2779. output = input.new(n)
  2780. for x in range(0, n):
  2781. output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
  2782. if reduction == 'mean':
  2783. return output.mean() / dim
  2784. elif reduction == 'sum':
  2785. return output.sum() / dim
  2786. elif target_dim == 0:
  2787. return output.squeeze(0) / dim
  2788. return output / dim
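# Illustrative cross-check (hypothetical helper): per sample the hinge terms are
# summed over the non-target classes and divided by the number of classes,
# matching F.multi_margin_loss with its defaults (p=1, margin=1).
def _example_check_multimarginloss_reference():
    inp = torch.randn(5, 4, dtype=torch.double)
    tgt = torch.randint(0, 4, (5,))
    assert torch.allclose(multimarginloss_reference(inp, tgt),
                          F.multi_margin_loss(inp, tgt))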
  2789. def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
  2790. def _cos(a, b):
  2791. cos = a.new(a.size(0))
  2792. for i in range(0, a.size(0)):
  2793. cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
  2794. return cos
  2795. output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
  2796. if reduction == 'mean':
  2797. return output.mean()
  2798. elif reduction == 'sum':
  2799. return output.sum()
  2800. return output
  2801. def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
  2802. reduction='mean'):
  2803. d_p = torch.pairwise_distance(anchor, positive, p, eps)
  2804. d_n = torch.pairwise_distance(anchor, negative, p, eps)
  2805. if swap:
  2806. d_s = torch.pairwise_distance(positive, negative, p, eps)
  2807. d_n = torch.min(d_n, d_s)
  2808. output = torch.clamp(margin + d_p - d_n, min=0.0)
  2809. if reduction == 'mean':
  2810. return output.mean()
  2811. elif reduction == 'sum':
  2812. return output.sum()
  2813. return output
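# Illustrative cross-check (hypothetical helper): the loss clamps
# margin + d(anchor, positive) - d(anchor, negative) at zero, so it should agree
# with F.triplet_margin_loss under the default p=2 distance and eps.
def _example_check_tripletmarginloss_reference():
    anchor, positive, negative = (torch.randn(5, 8, dtype=torch.double) for _ in range(3))
    assert torch.allclose(tripletmarginloss_reference(anchor, positive, negative),
                          F.triplet_margin_loss(anchor, positive, negative))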
  2814. def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
  2815. output = (-target * (input1 - input2) + margin).clamp(min=0)
  2816. if reduction == 'mean':
  2817. return output.mean()
  2818. elif reduction == 'sum':
  2819. return output.sum()
  2820. return output
2821. # This directly follows Graves et al.'s paper; in contrast to the production implementation, it does not use log-space.
  2822. def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
  2823. input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
  2824. target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
  2825. dt = log_probs.dtype
  2826. log_probs = log_probs.double() # we need the accuracy as we are not in logspace
  2827. targets = targets.long()
  2828. cum_target_lengths = target_lengths.cumsum(0)
  2829. losses = []
  2830. for i in range(log_probs.size(1)):
  2831. input_length = input_lengths[i].item()
  2832. target_length = target_lengths[i].item()
  2833. cum_target_length = cum_target_lengths[i].item()
  2834. targets_prime = targets.new_full((2 * target_length + 1,), blank)
  2835. if targets.dim() == 2:
  2836. targets_prime[1::2] = targets[i, :target_length]
  2837. else:
  2838. targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
  2839. probs = log_probs[:input_length, i].exp()
  2840. alpha = log_probs.new_zeros((target_length * 2 + 1,))
  2841. alpha[0] = probs[0, blank]
  2842. alpha[1] = probs[0, targets_prime[1]]
  2843. mask_third = (targets_prime[:-2] != targets_prime[2:])
  2844. for t in range(1, input_length):
  2845. alpha_next = alpha.clone()
  2846. alpha_next[1:] += alpha[:-1]
  2847. alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
  2848. alpha = probs[t, targets_prime] * alpha_next
  2849. losses.append(-alpha[-2:].sum().log()[None])
  2850. output = torch.cat(losses, 0)
  2851. if reduction == 'mean':
  2852. output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
  2853. elif reduction == 'sum':
  2854. output = output.sum()
  2855. output = output.to(dt)
  2856. return output
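# Illustrative cross-check on CPU (hypothetical helper): on a tiny batch the
# probability-space recursion above should agree with F.ctc_loss, which works in
# log-space; target lengths are kept short so every alignment is feasible.
def _example_check_ctcloss_reference():
    log_probs = F.log_softmax(torch.randn(6, 2, 5, dtype=torch.double), dim=2)
    targets = torch.tensor([1, 2, 2, 3, 1])  # concatenated targets for both samples
    input_lengths = torch.tensor([6, 6])
    target_lengths = torch.tensor([2, 3])
    assert torch.allclose(
        ctcloss_reference(log_probs, targets, input_lengths, target_lengths),
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths))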
  2857. loss_reference_fns: Dict['str', Callable] = {
  2858. 'KLDivLoss': kldivloss_reference,
  2859. 'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
  2860. 'NLLLoss': nllloss_reference,
  2861. 'NLLLossNd': nlllossNd_reference,
  2862. 'SmoothL1Loss': smoothl1loss_reference,
  2863. 'HuberLoss': huberloss_reference,
  2864. 'MultiLabelMarginLoss': multilabelmarginloss_reference,
  2865. 'HingeEmbeddingLoss': hingeembeddingloss_reference,
  2866. 'SoftMarginLoss': softmarginloss_reference,
  2867. 'MultiMarginLoss': multimarginloss_reference,
  2868. 'CosineEmbeddingLoss': cosineembeddingloss_reference,
  2869. 'TripletMarginLoss': tripletmarginloss_reference,
  2870. 'MarginRankingLoss': marginrankingloss_reference,
  2871. 'CTCLoss': ctcloss_reference,
  2872. 'CrossEntropyLoss': cross_entropy_loss_reference
  2873. }
  2874. criterion_tests = []
  2875. def single_batch_reference_criterion_fn(*args):
  2876. """Reference function for criterion supporting no batch dimensions.
  2877. The criterion is passed the input and target in batched form with a single item.
  2878. The output is squeezed to compare with the no-batch input.
  2879. """
  2880. criterion = args[-1]
  2881. def unsqueeze_inp(inp):
  2882. if isinstance(inp, (list, tuple)):
  2883. return [t.unsqueeze(0) for t in inp]
  2884. return inp.unsqueeze(0)
  2885. def flatten(xs):
  2886. result = []
  2887. if isinstance(xs, (list, tuple)):
  2888. for x in xs:
  2889. result.extend(flatten(x))
  2890. else:
  2891. result.append(xs)
  2892. return result
  2893. single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
  2894. output = criterion(*single_batch_input_args)
  2895. reduction = get_reduction(criterion)
  2896. if reduction == 'none':
  2897. return output.squeeze(0)
  2898. # reduction is 'sum' or 'mean' which results in a scalar
  2899. return output
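# Illustrative usage sketch (hypothetical helper): the reference is handed the
# unbatched tensors plus the criterion as the last argument; it unsqueezes a fake
# batch dimension, runs the criterion, and squeezes the result back only for the
# 'none' reduction.
def _example_single_batch_reference_criterion():
    criterion = nn.MSELoss()
    inp = torch.randn(3, dtype=torch.double)
    tgt = torch.randn(3, dtype=torch.double)
    assert torch.allclose(single_batch_reference_criterion_fn(inp, tgt, criterion),
                          criterion(inp, tgt))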
2900. # Check that regression criteria work with no batch dimensions
  2901. regression_criterion_no_batch = [
  2902. 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
  2903. ]
  2904. reductions = ['none', 'mean', 'sum']
  2905. for name, reduction in product(regression_criterion_no_batch, reductions):
  2906. regression_test_info = dict(
  2907. fullname=f"{name}_no_batch_dim_{reduction}",
  2908. constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction),
  2909. input_size=(3, ),
  2910. target_size=(3, ),
  2911. reference_fn=single_batch_reference_criterion_fn,
  2912. test_cpp_api_parity=False,
  2913. default_dtype=torch.double,
  2914. )
  2915. criterion_tests.append(regression_test_info)
  2916. for reduction in reductions:
  2917. regression_test_info = dict(
  2918. fullname=f"KLDivLoss_no_batch_dim_{reduction}",
  2919. constructor=lambda: nn.KLDivLoss(reduction=reduction),
  2920. input_fn=lambda: torch.rand((3,)).log(),
  2921. target_fn=lambda: torch.rand((3,)),
  2922. reference_fn=single_batch_reference_criterion_fn,
  2923. test_cpp_api_parity=False,
  2924. default_dtype=torch.double,
  2925. )
  2926. criterion_tests.append(regression_test_info)
2927. # Check that classification criteria work with no batch dimensions
  2928. # List of tuples of (name, input_fn, target_fn)
  2929. classification_criterion_no_batch = [
  2930. (
  2931. 'BCELoss',
  2932. lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
  2933. lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
  2934. ),
  2935. ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
  2936. ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
  2937. ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
  2938. ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
  2939. ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
  2940. (
  2941. 'CosineEmbeddingLoss',
  2942. lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
  2943. lambda: torch.tensor(1, dtype=torch.double)
  2944. ),
  2945. # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
  2946. ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
  2947. # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
  2948. (
  2949. 'TripletMarginLoss',
  2950. lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
  2951. lambda: torch.randn(9, dtype=torch.double)
  2952. ),
  2953. ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
  2954. ]
  2955. classification_criterion_no_batch_extra_info: Dict[str, dict] = {
  2956. 'MultiLabelMarginLoss': {'check_gradgrad': False},
  2957. }
  2958. # TODO : Fix these discrepancies
  2959. classification_cpp_parity = {
  2960. 'BCELoss': False,
  2961. 'BCEWithLogitsLoss': False,
  2962. 'HingeEmbeddingLoss': False,
  2963. 'NLLLoss': False,
  2964. 'SoftMarginLoss': False,
  2965. }
  2966. reductions = ['none', 'mean', 'sum']
  2967. for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
  2968. reductions):
  2969. classification_test_info = dict(
  2970. fullname=f"{name}_no_batch_dim_{reduction}",
  2971. constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction),
  2972. input_fn=lambda f=input_fn: f(),
  2973. target_fn=lambda f=target_fn: f(),
  2974. reference_fn=single_batch_reference_criterion_fn,
  2975. test_cpp_api_parity=True,
  2976. has_parity=classification_cpp_parity.get(name, True)
  2977. )
  2978. extra_info = classification_criterion_no_batch_extra_info.get(name, {})
  2979. classification_test_info.update(extra_info)
  2980. criterion_tests.append(classification_test_info)
  2981. class NNTestCase(TestCase):
  2982. # _forward is defined in classes inheriting from NNTestCase
  2983. @abstractmethod
  2984. def _forward(self, *args, **kwargs):
  2985. raise NotImplementedError
  2986. @abstractmethod
  2987. def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
  2988. raise NotImplementedError
  2989. @abstractmethod
  2990. def _zero_grad_parameters(self, module: nn.Module) -> None:
  2991. raise NotImplementedError
  2992. @abstractmethod
  2993. def _backward(self, module: nn.Module,
  2994. input: _TensorOrTensors, output: torch.Tensor,
  2995. grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
  2996. create_graph: bool = False):
  2997. raise NotImplementedError
  2998. def _jacobian(self, input, num_out):
  2999. if isinstance(input, tuple):
  3000. return tuple(self._jacobian(elem, num_out) for elem in input)
  3001. elif isinstance(input, list):
  3002. return [self._jacobian(elem, num_out) for elem in input]
  3003. else:
  3004. return torch.zeros(input.nelement(), num_out)
  3005. def _flatten_tensors(self, x):
  3006. if isinstance(x, torch.Tensor):
  3007. if x.is_sparse:
  3008. return x.to_dense().view(-1)
  3009. else:
  3010. return x.view(-1)
  3011. else:
  3012. return tuple(self._flatten_tensors(a) for a in x)
  3013. def _zero_grad_input(self, input):
  3014. if isinstance(input, torch.Tensor):
  3015. if input.requires_grad and input.grad is not None:
  3016. input.grad.zero_()
  3017. input.grad.detach_()
  3018. else:
  3019. for i in input:
  3020. self._zero_grad_input(i)
  3021. def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
  3022. output = self._forward(module, input)
  3023. output_size = output.nelement()
  3024. if jacobian_input:
  3025. jacobian_inp = self._jacobian(input, output_size)
  3026. flat_jacobian_input = list(_iter_tensors(jacobian_inp))
  3027. if jacobian_parameters:
  3028. num_param = sum(p.numel() for p in self._get_parameters(module)[0])
  3029. jacobian_param = torch.zeros(num_param, output_size)
  3030. for i in range(output_size):
  3031. param, d_param = self._get_parameters(module)
3032. # replace missing (None) gradients with zeros
  3033. d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]
  3034. d_out = torch.zeros_like(output)
  3035. flat_d_out = d_out.view(-1)
  3036. flat_d_out[i] = 1
  3037. if jacobian_parameters:
  3038. self._zero_grad_parameters(module)
  3039. # Tensors will accumulate gradient from multiple steps
  3040. if jacobian_input:
  3041. self._zero_grad_input(input)
  3042. d_input = self._backward(module, input, output, d_out)
  3043. if jacobian_input:
  3044. for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)):
  3045. jacobian_x[:, i] = d_x.contiguous().view(-1)
  3046. if jacobian_parameters:
  3047. jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
  3048. res: Tuple[torch.Tensor, ...] = tuple()
  3049. if jacobian_input:
  3050. res += jacobian_inp,
  3051. if jacobian_parameters:
  3052. res += jacobian_param,
  3053. return res
  3054. def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
  3055. def fw(*input):
  3056. return self._forward(module, input).detach()
  3057. res: Tuple[torch.Tensor, ...] = tuple()
  3058. if jacobian_input:
  3059. res += _get_numerical_jacobian(fw, input, eps=1e-6),
  3060. if jacobian_parameters:
  3061. param, _ = self._get_parameters(module)
  3062. to_cat = []
  3063. for p in param:
  3064. jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
  3065. # get_numerical_jacobian returns a list of tuples but we require a tensor
  3066. to_cat.append(jacobian[0][0])
  3067. res += (torch.cat(to_cat, 0),)
  3068. return res
  3069. def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
  3070. jacobian_parameters = bool(self._get_parameters(module)[0])
  3071. analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
  3072. numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
  3073. analytical_t = list(_iter_tensors(analytical))
  3074. numerical_t = list(_iter_tensors(numerical))
  3075. differences = []
  3076. for a, n in zip(analytical_t, numerical_t):
  3077. if a.numel() != 0:
  3078. differences.append(a.add(n, alpha=-1).abs().max())
  3079. # TODO: compare structure (ensure analytic jacobian has correct shape)
  3080. if len(differences) > 0:
  3081. self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var]
  3082. class TestBase:
  3083. _required_arg_names = {'constructor_args', 'input', 'extra_args'}
  3084. def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
  3085. self.desc = desc
  3086. self.fullname = fullname
  3087. self.constructor = constructor
  3088. self.reference_fn = reference_fn
  3089. for name in self._required_arg_names:
  3090. if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
  3091. if name in {'constructor_args', 'extra_args'}:
  3092. kwargs[name] = tuple()
  3093. else:
3094. raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or its size!")
  3095. self._extra_kwargs = kwargs
  3096. self._arg_cache = {}
  3097. def get_name(self):
  3098. if self.fullname is not None:
  3099. return 'test_' + self.fullname
  3100. test_name = 'test_' + self.constructor.__name__
  3101. if self.desc:
  3102. test_name += '_' + self.desc
  3103. return test_name
  3104. def _unpack(self, value):
  3105. if isinstance(value, torch.Tensor):
  3106. return value
  3107. elif is_iterable(value):
  3108. return type(value)(self._unpack(v) for v in value)
  3109. else:
  3110. return value
  3111. @property
  3112. def constructor_args(self):
  3113. return self._get_arg('constructor_args', True)
  3114. @property
  3115. def extra_args(self):
  3116. return self._get_arg('extra_args', True)
  3117. def _get_arg(self, name, unpack):
  3118. assert name in self._required_arg_names
  3119. if name not in self._arg_cache:
  3120. fn_name = name + '_fn'
  3121. size_name = name + '_size'
  3122. if name in self._extra_kwargs:
  3123. self._arg_cache[name] = self._extra_kwargs[name]
  3124. elif fn_name in self._extra_kwargs:
  3125. self._arg_cache[name] = self._extra_kwargs[fn_name]()
  3126. else:
  3127. assert size_name in self._extra_kwargs, \
  3128. f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"
  3129. def map_tensor_sizes(sizes):
  3130. if isinstance(sizes, list):
  3131. return [map_tensor_sizes(s) for s in sizes]
  3132. elif isinstance(sizes, torch.Tensor):
  3133. return sizes.double()
  3134. else:
  3135. return torch.randn(sizes)
  3136. self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
  3137. return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]
  3138. def _get_input(self, unpack=True):
  3139. return self._get_arg('input', unpack)
  3140. def __call__(self, test_case):
  3141. raise NotImplementedError
  3142. class ModuleTest(TestBase):
  3143. @abstractmethod
  3144. def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
  3145. raise NotImplementedError
  3146. def __init__(self, *args, **kwargs):
  3147. super().__init__(*args, **kwargs)
  3148. self.jacobian_input = kwargs.get('jacobian_input', True)
  3149. self.should_test_cuda = kwargs.get('test_cuda', True)
  3150. self.should_test_pickle = kwargs.get('pickle', True)
  3151. self.check_gradgrad = kwargs.get('check_gradgrad', True)
  3152. self.FIXME_no_cuda_gradgrad_comparison = \
  3153. kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
  3154. self.precision = kwargs.get('precision', 2e-4)
  3155. self.check_forward_only = kwargs.get('check_forward_only', False)
  3156. self.default_dtype = kwargs.get('default_dtype', None)
  3157. if self.default_dtype is None:
  3158. self.default_dtype = torch.get_default_dtype()
  3159. def __call__(self, test_case):
  3160. with set_default_dtype(self.default_dtype):
  3161. module = self.constructor(*self.constructor_args)
  3162. input = self._get_input()
  3163. if self.reference_fn is not None:
  3164. out = test_case._forward(module, input)
  3165. ref_input = deepcopy(input)
  3166. ref_module = deepcopy(module)
  3167. expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
  3168. test_case.assertEqual(out, expected_out, exact_dtype=False)
  3169. if self.check_forward_only:
  3170. return
  3171. self.test_noncontig(test_case, module, input)
  3172. if self.should_test_pickle:
3173. # TODO: do this with in-memory files once torch.save supports it
  3174. with tempfile.TemporaryFile() as f:
  3175. test_case._forward(module, input)
  3176. torch.save(module, f)
  3177. f.seek(0)
  3178. module_copy = torch.load(f)
  3179. test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
  3180. self._do_test(test_case, module, input)
  3181. def noncontiguize(self, obj):
  3182. if isinstance(obj, list):
  3183. return [self.noncontiguize(o) for o in obj]
  3184. elif isinstance(obj, tuple):
  3185. return tuple(self.noncontiguize(o) for o in obj)
  3186. tensor = obj
  3187. ndim = tensor.dim()
3188. # Always making only the last dimension noncontiguous can easily hide
  3189. # bugs because .view(-1) will still work. So try to find a dim with size
  3190. # > 1 and make that non-contiguous, i.e., stack + select on the
  3191. # dimension directly after that.
  3192. dim = ndim
  3193. for d in range(ndim):
  3194. if tensor.size(d) > 1:
  3195. dim = d + 1
  3196. break
  3197. noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
  3198. assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
  3199. noncontig.requires_grad = tensor.requires_grad
  3200. return noncontig
  3201. def test_noncontig(self, test_case, module, input):
3202. # skip scalar inputs; they can't be made non-contiguous
  3203. if isinstance(input, torch.Tensor) and input.dim() == 0:
  3204. return
  3205. if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
  3206. return
  3207. test_case._zero_grad_parameters(module)
  3208. test_case._zero_grad_input(input)
  3209. with freeze_rng_state():
  3210. output = test_case._forward(module, input)
  3211. if getattr(module, "return_indices", False):
  3212. output = output[0]
  3213. grad_output = output.new(output.shape).normal_()
  3214. output = output.clone()
  3215. d_input = deepcopy(test_case._backward(module, input, output, grad_output))
  3216. d_param = deepcopy(test_case._get_parameters(module)[1])
  3217. nc_input = self.noncontiguize(input)
  3218. nc_grad_output = self.noncontiguize(grad_output)
  3219. for contig_i, contig_g in product((True, False), repeat=2):
  3220. i = input if contig_i else nc_input
3221. # Some ops, e.g., nn.Flatten, return a gradient that shares
  3222. # storage with the grad_output. Hence we copy here.
  3223. go = deepcopy(grad_output if contig_g else nc_grad_output)
  3224. test_case._zero_grad_parameters(module)
  3225. test_case._zero_grad_input(i)
  3226. with freeze_rng_state():
  3227. out = test_case._forward(module, i)
  3228. if getattr(module, "return_indices", False):
  3229. out = out[0]
  3230. grad = test_case._backward(module, i, out, go)
  3231. test_case.assertEqual(out, output)
  3232. test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
  3233. test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
  3234. def test_cuda(self, test_case):
  3235. if not TEST_CUDA or not self.should_test_cuda:
  3236. raise unittest.SkipTest('Excluded from CUDA tests')
  3237. with set_default_dtype(self.default_dtype):
  3238. cpu_input = self._get_input()
  3239. type_map = {torch.double: torch.float}
  3240. cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
  3241. is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)
  3242. gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
  3243. cpu_module = self.constructor(*self.constructor_args)
  3244. gpu_module = self.constructor(*self.constructor_args).float().cuda()
  3245. cpu_param = test_case._get_parameters(cpu_module)
  3246. gpu_param = test_case._get_parameters(gpu_module)
  3247. for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
  3248. gpu_p.data.copy_(cpu_p)
  3249. test_case._zero_grad_input(cpu_input_tuple)
  3250. test_case._zero_grad_input(gpu_input_tuple)
  3251. test_case._zero_grad_parameters(cpu_module)
  3252. test_case._zero_grad_parameters(gpu_module)
  3253. cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
  3254. gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
  3255. if getattr(cpu_module, "return_indices", False):
  3256. cpu_output = cpu_output[0]
  3257. gpu_output = gpu_output[0]
  3258. test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)
  3259. # Run backwards on CPU and GPU and compare results
  3260. for _ in range(5):
  3261. cpu_gradOutput = cpu_output.clone().normal_()
  3262. gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
  3263. cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
  3264. gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
  3265. test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
  3266. for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
  3267. test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
  3268. # Run double-backwards on CPU and GPU and compare results
  3269. if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
  3270. cpu_output = cpu_module(*cpu_input_tuple)
  3271. gpu_output = gpu_module(*gpu_input_tuple)
  3272. if getattr(cpu_module, "return_indices", False):
  3273. cpu_output = cpu_output[0]
  3274. gpu_output = gpu_output[0]
  3275. cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
  3276. gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
  3277. gpu_gradOutput.requires_grad = True
  3278. cpu_gradInputs = torch.autograd.grad(
  3279. cpu_output,
  3280. cpu_input_tuple + tuple(cpu_module.parameters()),
  3281. cpu_gradOutput,
  3282. create_graph=True)
  3283. gpu_gradInputs = torch.autograd.grad(
  3284. gpu_output,
  3285. gpu_input_tuple + tuple(gpu_module.parameters()),
  3286. gpu_gradOutput,
  3287. create_graph=True)
  3288. for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
  3289. test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)
  3290. # We mix output into the second backwards computation so that
  3291. # torch.autograd.grad doesn't complain that some inputs
  3292. # are unreachable (which can happen if you differentiate
3293. # only on the gradient).
  3294. if is_any_input_complex:
  3295. outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
  3296. outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
  3297. else:
  3298. outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
  3299. outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
  3300. cpu_gg = torch.autograd.grad(
  3301. outputs_cpu,
  3302. cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
  3303. retain_graph=True)
  3304. gpu_gg = torch.autograd.grad(
  3305. outputs_gpu,
  3306. gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
  3307. retain_graph=True)
  3308. test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
  3309. for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
  3310. test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)
  3311. self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
  3312. class InputVariableMixin:
  3313. def _get_input(self):
  3314. input = TestBase._get_input(self, False) # type: ignore[arg-type]
  3315. def map_variables(i):
  3316. if isinstance(i, torch.Tensor):
  3317. if i.is_floating_point() or i.is_complex():
  3318. i.requires_grad = True
  3319. return i
  3320. else:
  3321. return type(i)(map_variables(elem) for elem in i)
  3322. return map_variables(input)
  3323. class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc]
  3324. def __init__(self, *args, **kwargs):
  3325. super().__init__(*args, **kwargs)
  3326. self.cudnn = kwargs.get('cudnn', False)
  3327. self.check_inplace = kwargs.get('check_inplace', False)
  3328. self.check_gradgrad = kwargs.get('check_gradgrad', True)
  3329. self.skip_double = kwargs.get('skip_double', False)
  3330. self.skip_half = kwargs.get('skip_half', False)
  3331. self.with_tf32 = kwargs.get('with_tf32', False)
  3332. self.tf32_precision = kwargs.get('tf32_precision', 0.001)
  3333. self.test_cpu = kwargs.get('test_cpu', True)
  3334. self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
  3335. self.check_batched_grad = kwargs.get('check_batched_grad', True)
  3336. self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
  3337. self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
  3338. self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
  3339. def _check_gradients(self, test_case, module, input_tuple):
  3340. params = tuple(x for x in module.parameters())
  3341. num_inputs = len(input_tuple)
  3342. def fn_to_gradcheck(*inputs_and_params, **kwargs):
  3343. assert not kwargs
  3344. return test_case._forward(module, inputs_and_params[:num_inputs])
  3345. # gradcheck doesn't support operators that take in dense inputs but
3346. # produce sparse gradients for their parameters. This only happens for nn.Embedding
  3347. # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
  3348. # is a slightly different version of gradcheck that can handle this.
  3349. if self.has_sparse_gradients:
  3350. assert num_inputs == 1
  3351. test_input_jacobian = torch.is_floating_point(input_tuple[0])
  3352. test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
  3353. else:
  3354. test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
  3355. check_batched_grad=self.check_batched_grad,
  3356. fast_mode=self.gradcheck_fast_mode,
  3357. check_forward_ad=self.supports_forward_ad))
  3358. if self.check_gradgrad:
  3359. test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
  3360. check_batched_grad=self.check_batched_grad,
  3361. fast_mode=self.gradcheck_fast_mode,
  3362. check_fwd_over_rev=self.supports_fwgrad_bwgrad))
  3363. def _do_test(self, test_case, module, input):
  3364. num_threads = torch.get_num_threads()
  3365. torch.set_num_threads(1)
  3366. input_tuple = input if isinstance(input, tuple) else (input,)
  3367. self._check_gradients(test_case, module, input_tuple)
  3368. # check if module can be printed
  3369. module.__repr__()
  3370. if self.check_inplace:
  3371. # check if the inplace variant of the module gives the same result
  3372. # as the out-of-place
  3373. # check_inplace doesn't support multiple input tensors, since we don't have any modules
  3374. # that modify the inputs in-place and that accept more than one input
  3375. assert len(input_tuple) == 1
  3376. input = input_tuple[0]
  3377. module_ip = self.constructor(*self.constructor_args, inplace=True)
  3378. input_version = input._version
  3379. with freeze_rng_state():
  3380. output = module(input)
  3381. test_case.assertEqual(input._version, input_version)
  3382. input_ip = deepcopy(input)
  3383. input_ip_clone = input_ip.clone()
  3384. with freeze_rng_state():
  3385. output_ip = module_ip(input_ip_clone)
  3386. test_case.assertNotEqual(input_ip_clone._version, input_version)
  3387. test_case.assertEqual(output, output_ip)
  3388. grad = output.data.clone().normal_()
  3389. if input.grad is not None:
  3390. with torch.no_grad():
  3391. input.grad.zero_()
  3392. if input_ip.grad is not None:
  3393. with torch.no_grad():
  3394. input_ip.grad.zero_()
  3395. output.backward(grad)
  3396. output_ip.backward(grad)
  3397. test_case.assertEqual(input.grad, input_ip.grad)
  3398. def assert_module_parameters_are(tensor_type, device_id=None):
  3399. for p in module.parameters():
  3400. test_case.assertIsInstance(p, tensor_type)
  3401. if device_id is not None:
  3402. test_case.assertEqual(p.get_device(), device_id)
  3403. if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
  3404. # check that cuda() moves module parameters to correct GPU device,
  3405. # and that float() casts parameters correctly
  3406. input_tuple = tuple(t.cuda() for t in input_tuple)
  3407. module.float().cuda()
  3408. module(*input_tuple)
  3409. assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
  3410. if torch.cuda.device_count() > 1:
  3411. input_tuple = tuple(t.cuda(1) for t in input_tuple)
  3412. module.cuda(1)
  3413. with torch.cuda.device(1):
  3414. module(*input_tuple)
  3415. assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined]
  3416. else:
  3417. # check that float()/double() casters work correctly
  3418. def to_type(tensor, real, complex):
  3419. if tensor.is_complex():
  3420. return tensor.to(complex)
  3421. elif tensor.is_floating_point():
  3422. return tensor.to(real)
  3423. else:
  3424. return tensor
  3425. def to_half(x):
  3426. # TODO: torch.complex32 when properly supported
  3427. return to_type(x, torch.float16, None)
  3428. def to_single(x):
  3429. return to_type(x, torch.float32, torch.complex64)
  3430. def to_double(x):
  3431. return to_type(x, torch.float64, torch.complex128)
  3432. # to float
  3433. input_tuple = tuple(to_single(t) for t in input_tuple)
  3434. module.float()
  3435. module(*input_tuple)
  3436. assert_module_parameters_are(torch.FloatTensor)
  3437. # and back to double
  3438. input_tuple = tuple(to_double(t) for t in input_tuple)
  3439. module.double()
  3440. module(*input_tuple)
  3441. assert_module_parameters_are(torch.DoubleTensor)
  3442. if TEST_CUDA and self.should_test_cuda:
  3443. # check that cuda() moves module parameters to correct GPU device,
  3444. # and that float() casts parameters correctly
  3445. # to GPU0
  3446. input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
  3447. module.float().cuda()
  3448. module(*input_tuple)
  3449. assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
  3450. # to CPU
  3451. input_tuple = tuple(t.cpu() for t in input_tuple)
  3452. module.cpu()
  3453. module(*input_tuple)
  3454. assert_module_parameters_are(torch.FloatTensor)
  3455. # back to GPU0
  3456. input_tuple = tuple(t.cuda() for t in input_tuple)
  3457. module.cuda()
  3458. module(*input_tuple)
  3459. assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
3460. # test that the module's forward pass runs correctly without cuDNN
  3461. if self.cudnn:
  3462. with torch.backends.cudnn.flags(enabled=False):
  3463. module(*input_tuple)
  3464. assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
  3465. if torch.cuda.device_count() >= 2:
  3466. # test cross-GPU transfer works
  3467. # to GPU1
  3468. input_tuple = tuple(t.cuda(1) for t in input_tuple)
  3469. module.cuda(1)
  3470. with torch.cuda.device(1):
  3471. module(*input_tuple)
  3472. assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined]
  3473. if not self.skip_double:
  3474. # test double()
  3475. input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
  3476. module.double().cuda()
  3477. module(*input_tuple)
  3478. assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined]
  3479. # test half()
  3480. if not self.skip_half:
  3481. input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
  3482. module.half().cuda()
  3483. module(*input_tuple)
  3484. assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined]
  3485. torch.set_num_threads(num_threads)
  3486. def _get_target(self):
  3487. return self._get_arg('target', False)
  3488. @property
  3489. def constructor_args(self):
  3490. return self._get_arg('constructor_args', False)
  3491. class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc]
  3492. # TODO: check that criterions don't ignore grad_output
  3493. _required_arg_names = TestBase._required_arg_names.union({'target'})
  3494. def __init__(self, *args, **kwargs):
  3495. super().__init__(*args, **kwargs)
  3496. self.should_test_cuda = kwargs.get('test_cuda', True)
  3497. self.check_forward_only = kwargs.get('check_forward_only', False)
  3498. self.check_gradgrad = kwargs.get('check_gradgrad', True)
  3499. self.check_half = kwargs.get('check_half', True)
  3500. self.check_bfloat16 = kwargs.get('check_bfloat16', False)
  3501. self.check_complex = kwargs.get('check_complex', False)
  3502. self.test_cpu = kwargs.get('test_cpu', True)
  3503. self.with_tf32 = kwargs.get('with_tf32', True)
  3504. self.tf32_precision = kwargs.get('tf32_precision', 0.001)
  3505. self.check_batched_grad = kwargs.get('check_batched_grad', True)
  3506. self.default_dtype = kwargs.get('default_dtype', None)
  3507. if self.default_dtype is None:
  3508. self.default_dtype = torch.get_default_dtype()
  3509. def __call__(self, test_case):
  3510. with set_default_dtype(self.default_dtype):
  3511. module = self.constructor(*self.constructor_args)
  3512. input = self._get_input()
  3513. # Check that these methods don't raise errors
  3514. module.__repr__()
  3515. str(module)
  3516. target = self._get_target()
  3517. if self.reference_fn is not None:
  3518. out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
  3519. ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
  3520. expected_out = self.reference_fn(*ref_args)
  3521. test_case.assertEqual(out, expected_out)
  3522. if self.check_forward_only:
  3523. return
  3524. params = tuple(x for x in module.parameters())
  3525. if not isinstance(input, tuple):
  3526. inputs = (input,) + params + (target,)
  3527. def apply_fn(input, target, *params):
  3528. return module(input, target)
  3529. else:
  3530. inputs = input + params + (target,)
  3531. def apply_fn(input1, input2, target, *params): # type: ignore[misc]
  3532. return module(input1, input2, target)
  3533. gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
  3534. if self.check_gradgrad:
  3535. gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

    def test_cuda(self, test_case, dtype, extra_args=None):
        def convert_dtype(obj, dtype, requires_grad=False):
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        with set_default_dtype(self.default_dtype):
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args)

            # Convert input, target and module parameters to dtype
            cpu_input = convert_dtype(cpu_input, dtype, True)
            if cpu_target.is_floating_point() or cpu_target.is_complex():
                cpu_target = convert_dtype(cpu_target, dtype)
            cpu_module.type(dtype)
            gpu_module.type(dtype)

            # GPU setup
            gpu_input = to_gpu(cpu_input)
            gpu_target = to_gpu(cpu_target)
            gpu_module.cuda()

            # torch.HalfTensor doesn't support most operations; convert back to the default dtype
            if dtype in {torch.half, torch.bfloat16}:
                cpu_input = self._get_input()
                cpu_target = self._get_target()
                # Loss modules with weights require consistent input/module weight types
                cpu_module = self.constructor(*self.constructor_args)

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
            # dtype used to be allowed to be None, so set precision this way instead of via a precision map
            test_case.assertEqual(cpu_output, gpu_output,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
            # dtype used to be allowed to be None, so set precision this way instead of via a precision map
            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', False)
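

# Illustrative usage sketch (assumptions: the criterion, shapes and descriptor
# kwargs below -- 'input_size', 'target_fn', 'reference_fn' -- are hypothetical
# example values, not a definitive recipe).  A CriterionTest is built from such
# descriptors and then invoked with a TestCase instance:
#
#     test = CriterionTest(
#         constructor=nn.L1Loss,
#         constructor_args=(),
#         input_size=(2, 3, 4),
#         target_fn=lambda: torch.randn(2, 3, 4),
#         reference_fn=lambda i, t, m: (i - t).abs().mean(),
#     )
#     test(self)                               # CPU forward vs. reference + gradcheck
#     test.test_cuda(self, dtype=torch.float)  # CPU vs. CUDA outputs and gradients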


def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
    """Run ``op`` in float32 and in bfloat16 and check that the outputs and the
    input gradients agree within ``prec``."""
    # fp32 compute
    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
    if scale_factor is not None:
        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
    out1 = op(input1)
    grad_input1 = torch.randn_like(out1, device=device)
    out1.backward(grad_input1)

    # bfloat16 compute
    op_bfp16 = op.bfloat16()
    input2 = input1.detach().bfloat16().requires_grad_()
    grad_input2 = grad_input1.bfloat16()
    out2 = op_bfp16(input2)
    out2.backward(grad_input2)

    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
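

# Example call (hypothetical op/shape): compare a module's float32 and bfloat16
# forward/backward results on CPU within ~1e-2 tolerance.
#
#     _test_bfloat16_ops(self, torch.nn.GELU(), 'cpu', inp_dims=(8, 16), prec=1e-2)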


def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
    """Feed a (typically zero-sized) input through ``module``; optionally run the
    backward pass and check that the output size matches the input and that all
    gradients are zero."""
    if not inference:
        inp.requires_grad_(True)
    out = module(inp)
    if not inference:
        gO = torch.rand_like(out)
        out.backward(gO)
    if check_size:
        test_case.assertEqual(out.size(), inp.size())
    if not inference:
        for p in module.parameters():
            if p.requires_grad:
                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
        test_case.assertEqual(inp.grad, torch.zeros_like(inp))
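

# Example call (hypothetical module/shape): verify that a zero-sized batch flows
# through forward and backward, with matching output size and all-zero gradients.
#
#     _test_module_empty_input(self, nn.Linear(5, 5), torch.randn(0, 5))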


def _create_basic_net():
    """Build a standalone ``Layer``, a ``Net`` that nests a ``Layer`` plus its own
    parameter and buffer, and a ``Sequential`` that reuses the same ``Net`` twice."""
    class Layer(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
            self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = Layer()
            self.dummy_param = nn.Parameter(torch.empty(3, 5))
            self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))

    l = Layer()
    n = Net()
    s = nn.Sequential(n, n)

    return l, n, s
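

# Example (illustrative): the returned Sequential reuses the same Net instance, so
# parameters() deduplicates the shared submodule.
#
#     l, n, s = _create_basic_net()
#     assert len(list(n.parameters())) == 2   # dummy_param + l1.layer_dummy_param
#     assert len(list(s.parameters())) == 2   # the shared Net is counted once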