# mypy: ignore-errors

r"""Importing this file includes common utility methods and base classes for
checking quantization api and properties of resulting modules.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
from torch.ao.nn.intrinsic import _FusedModule
import torch.distributed as dist
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization import (
    QuantType,
    default_dynamic_qat_qconfig,
    default_embedding_qat_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_pt2e import (
    _convert_to_reference_decomposed_fx,
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.backend_config import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \
    default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
    propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
    get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \
    QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping
from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_qat_module_mappings,
)
from torch.testing._internal.common_quantized import (
    override_quantized_engine,
)
from torch.jit.mobile import _load_for_lite_interpreter

try:
    # graph mode quantization based on fx
    from torch.ao.quantization.quantize_fx import (
        prepare_fx,
        prepare_qat_fx,
        convert_fx,
        convert_to_reference_fx,
    )
    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.fx.graph import Node
    from torch.fx import GraphModule
    HAS_FX = True
except ImportError:
    HAS_FX = False

import copy
import io
import functools
import time
import os
import unittest
import numpy as np

from torch.testing import FileCheck
from typing import Callable, Tuple, Dict, Any, Union, Type, Optional
import torch._dynamo as torchdynamo

class NodeSpec:
    ''' Used for checking GraphModule Node
    '''
    def __init__(self, op, target):
        '''
        op: call_function | call_module
        target:
          for call_function, target would be a function
          for call_module, target would be the type of PyTorch module
        '''
        self.op = op
        self.target = target

    @classmethod
    def call_function(cls, target):
        return NodeSpec('call_function', target)

    @classmethod
    def call_method(cls, target):
        return NodeSpec('call_method', target)

    @classmethod
    def call_module(cls, target):
        return NodeSpec('call_module', target)

    def __hash__(self):
        return hash((self.op, self.target))

    def __eq__(self, other):
        if not isinstance(other, NodeSpec):
            return NotImplemented
        return self.op == other.op and self.target == other.target

    def __repr__(self):
        return repr(self.op) + " " + repr(self.target)

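# Illustrative usage of NodeSpec (a sketch, not exercised at import time):
# specs hash and compare by (op, target), so they can serve as dict keys when
# counting node occurrences in a GraphModule, as checkGraphModuleNodes does.
#
#   conv_spec = NodeSpec.call_module(nn.Conv2d)
#   quant_spec = NodeSpec.call_function(torch.quantize_per_tensor)
#   assert conv_spec == NodeSpec('call_module', nn.Conv2d)
#   occurrences = {conv_spec: 1, quant_spec: 1}
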
def get_supported_device_types():
    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']

def test_only_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for inp in calib_data:
        output = model(*inp)

_default_loss_fn = torch.nn.CrossEntropyLoss()

def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
    r"""
    Default train function: takes a torch.utils.data.Dataset and trains the
    model on the dataset
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loss, correct, total = 0, 0, 0
    for i in range(10):
        model.train()

        for data, target in train_data:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return train_loss, correct, total

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

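# Illustrative usage of AverageMeter (a sketch): track a running average of a
# per-batch metric, weighting each update by the batch size.
#
#   top1 = AverageMeter('Acc@1', ':6.2f')
#   top1.update(87.5, n=32)   # batch of 32 with 87.5% accuracy
#   top1.update(90.0, n=32)
#   print(top1)               # "Acc@1  90.00 ( 88.75)"
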
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

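# Illustrative call (a sketch): for logits of shape (batch, num_classes),
# accuracy(...) returns one percentage tensor per requested k.
#
#   logits = torch.randn(8, 10)
#   labels = torch.randint(0, 10, (8,))
#   acc1, acc5 = accuracy(logits, labels, topk=(1, 5))  # each a 1-element tensor
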
def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    model.train()
    cnt = 0
    for image, target in data_loader:
        start_time = time.time()
        print('.', end='')
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        if cnt >= ntrain_batches:
            return
    return

def ddp_setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def ddp_cleanup():
    dist.destroy_process_group()

def run_ddp(rank, world_size, prepared):
    ddp_setup(rank, world_size)
    prepared.cuda()
    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
    prepared.to(rank)
    model_with_ddp = prepared
    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
    ddp_cleanup()

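# Illustrative launch (a sketch): run_ddp is designed to be spawned once per
# rank, e.g. via torch.multiprocessing; `world_size` and `prepared_model` are
# placeholder names here. Note that `criterion` and `dataset` above are
# unresolved module-level names (hence the noqa), so a real caller would need
# to define them first.
#
#   import torch.multiprocessing as mp
#   mp.spawn(run_ddp, args=(world_size, prepared_model), nprocs=world_size)
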
def convert_dynamic(module):
    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)

def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)

def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
    W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Trim the W_scale and W_zero_point lists to out_channels entries
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If overflow occurs in
    # the qconv implementation but not in the reference, we cannot exactly
    # match the results with the reference.
    # Please see the comment in the qconv implementation file
    # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0,
            dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)

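# Illustrative call (a sketch): W_scale and W_zero_point are passed as lists
# and are tiled then trimmed to out_channels entries above. For a 2d conv:
#
#   X, X_q, W, W_q, b = _make_conv_test_input(
#       batch_size=1, in_channels_per_group=2, input_feature_map_size=(8, 8),
#       out_channels_per_group=2, groups=1, kernel_size=(3, 3),
#       X_scale=0.1, X_zero_point=0, W_scale=[0.2], W_zero_point=[0],
#       use_bias=True, use_channelwise=False)
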
def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min,
        X_value_max,
        sizes  # Infer the size of tensor to do the add
    )
    X = scale * (X_init - zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
    return X, X_q

def skipIfNoFBGEMM(fn):
    reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.'
    if isinstance(fn, type):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

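# Illustrative usage (a sketch): the skip decorators in this file work on both
# test classes (via the unittest skip attributes) and individual test methods
# (via the wrapper that raises unittest.SkipTest).
#
#   @skipIfNoFBGEMM
#   class TestQuantizedOps(QuantizationTestCase):
#       @skipIfNoFBGEMM
#       def test_fbgemm_only_path(self):
#           ...
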
def skipIfNoQNNPACK(fn):
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoCaffe2AtenFallback(fn):
    # NOTE: the enclosing `def` line and `reason` string for this wrapper were
    # lost in extraction; the (hypothetical) name and reason below are
    # reconstructed from the surviving body, which skips when the Caffe2 ATen
    # fallback is disabled.
    reason = 'Test requires the Caffe2 ATen fallback.'

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.onnx._CAFFE2_ATEN_FALLBACK:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def withQNNPACKBackend(fn):
    # TODO(future PR): consider combining with skipIfNoQNNPACK,
    # will require testing of existing callsites
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        with override_quantized_engine('qnnpack'):
            fn(*args, **kwargs)
    return wrapper

def skipIfNoONEDNN(fn):
    reason = 'Quantized operations require ONEDNN.'
    if isinstance(fn, type):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoONEDNNBF16(fn):
    reason = 'Quantized operations require BF16 support.'
    if isinstance(fn, type):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoX86(fn):
    reason = 'Quantized operations require X86.'
    if isinstance(fn, type):
        if 'x86' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'x86' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoDynamoSupport(fn):
    reason = "dynamo is not supported."
    if isinstance(fn, type):
        if not torchdynamo.is_dynamo_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torchdynamo.is_dynamo_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoInductorSupport(fn):
    reason = "inductor is not supported."
    if isinstance(fn, type):
        if not torchdynamo.is_inductor_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torchdynamo.is_inductor_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

try:
    import torchvision  # noqa: F401
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

def get_script_module(model, tracing, data):
    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)

def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
    """
    Convert lengths to offsets for embedding_bag
    """
    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
    tt[1:] = t
    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
    if use_begin_offset:
        return tt[:-1]
    return tt[1:]

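# Worked example (a sketch): for bag lengths [2, 3, 1] the cumulative sums are
# [0, 2, 5, 6]; with use_begin_offset=True this returns the begin offsets
# [0, 2, 5], otherwise the end offsets [2, 5, 6].
#
#   lengths = np.array([2, 3, 1])
#   offsets = lengths_to_offsets(lengths)  # tensor([0, 2, 5])
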
def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
    assert w.dim() == 2
    w = w.transpose(0, 1).contiguous()
    assert q_group_size > 1
    assert w.shape[-1] % q_group_size == 0

    to_quant = w.reshape(-1, q_group_size)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    min_int = 0
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    assert torch.isnan(scales).sum() == 0

    zeros = min_val + scales * (2 ** (n_bit - 1))
    assert torch.isnan(zeros).sum() == 0

    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
    assert torch.isnan(out).sum() == 0

    out = out.to(dtype=torch.int32).reshape(w.shape)

    # Scales and zeros for the same q-group should be contiguous, so we can
    # load as a 32-bit word
    scales = scales.view(w.shape[0], -1)
    zeros = zeros.view(w.shape[0], -1)
    scales_and_zeros = (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        ).transpose(0, 1).contiguous()
    )

    return out, scales_and_zeros

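# Illustrative shapes (a sketch): for w of shape (32, 64), the function first
# transposes to (64, 32); with n_bit=4 and q_group_size=16 each row holds
# 32 / 16 = 2 groups, so it returns an int32 tensor of shape (64, 32) and a
# scales_and_zeros tensor of shape (num_groups, rows, 2) = (2, 64, 2).
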
def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
    # default setup for affine quantization of activations
    x_dtype = x.dtype
    x = x.float()
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # clamp scales away from zero; they are computed in fp32 here and cast
    # back to the input's original dtype on return
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales.to(x_dtype), zero_points

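# Illustrative call (a sketch): symmetric per-row int8 quantization of a 2-D
# tensor, the setup used for dynamic activation quantization.
#
#   x = torch.randn(4, 16)
#   q, scales, zps = _dynamically_quantize_per_channel(x, -128, 127, torch.int8)
#   # q: int8 tensor of shape (4, 16); scales: (4,); zps: zeros of shape (4,)
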
# QuantizationTestCase used as a base class for testing quantization on modules
class QuantizationTestCase(TestCase):
    def setUp(self):
        super().setUp()
        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
        self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)]
        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]

        self.img_data_dict = {1 : self.img_data_1d,
                              2 : self.img_data_2d,
                              3 : self.img_data_3d}

        # Quant types that produce statically quantized ops
        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
        # All quant types for (fx based) graph mode quantization
        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]

    def checkNoPrepModules(self, module):
        r"""Checks the module does not contain child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkNoQconfig(self, module):
        r"""Checks the module does not contain qconfig
        """
        self.assertFalse(hasattr(module, 'qconfig'))

        for child in module.children():
            self.checkNoQconfig(child)

    def checkHasPrepModules(self, module):
        r"""Checks the module contains child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
        r"""Checks the module or module's leaf descendants
            have observers in preparation for quantization
        """
        if propagate_qconfig_list is None:
            propagate_qconfig_list = get_default_qconfig_propagation_list()
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

        # check if a module is a leaf module, ignoring activation_post_process attribute
        def is_leaf_module(module):
            submodule_name_count = 0
            for name, _ in module.named_children():
                if name != 'activation_post_process':
                    submodule_name_count += 1
            return submodule_name_count == 0

        if hasattr(module, 'qconfig') and module.qconfig is not None and \
                ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
                  and type(module) in propagate_qconfig_list) or
                 type(module) in float_to_observed_module_class_mapping.keys()) and \
                not isinstance(module, torch.ao.quantization.DeQuantStub):
            self.assertTrue(hasattr(module, 'activation_post_process'),
                            'module: ' + str(type(module)) + ' does not have an observer')
        # we don't need to check observers for child modules of the
        # qat modules
        if type(module) not in get_default_qat_module_mappings().values() and \
                type(module) not in float_to_observed_module_class_mapping.values() and \
                not isinstance(module, _FusedModule):
            for child in module.children():
                if type(child) in [nn.Dropout]:
                    continue
                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)

    def checkQuantDequant(self, mod):
        r"""Checks that mod has nn.Quantize and
            nn.DeQuantize submodules inserted
        """
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkWrappedQuantizedLinear(self, mod):
        r"""Checks that mod has been swapped for an nnq.Linear
            module, the bias is qint32, and that the module
            has Quantize and DeQuantize submodules
        """
        self.assertEqual(type(mod.module), nnq.Linear)
        self.checkQuantDequant(mod)

    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod), nnq.Linear)

    def checkDynamicQuantizedLinear(self, mod, dtype):
        r"""Checks that mod has been swapped for an nnqd.Linear
            module, the bias is float.
        """
        self.assertEqual(type(mod), nnqd.Linear)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
        r"""Checks that mod has been swapped for an nniqd.LinearReLU
            module, the bias is float.
        """
        self.assertEqual(type(mod), nniqd.LinearReLU)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def check_eager_serialization(self, ref_model, loaded_model, x):
        # Check state dict serialization and torch.save APIs
        model_dict = ref_model.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_model.load_state_dict(loaded_dict)
        ref_out = ref_model(*x)
        load_out = loaded_model(*x)

        def check_outputs(ref_out, load_out):
            self.assertEqual(ref_out[0], load_out[0])
            if isinstance(ref_out[1], tuple):
                self.assertEqual(ref_out[1][0], load_out[1][0])
                self.assertEqual(ref_out[1][1], load_out[1][1])
            else:
                self.assertEqual(ref_out[1], load_out[1])
        check_outputs(ref_out, load_out)

        b = io.BytesIO()
        torch.save(ref_model, b)
        b.seek(0)
        loaded = torch.load(b)
        load_out = loaded(*x)
        check_outputs(ref_out, load_out)

    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
        weight = ref_model.get_weight()
        bias = ref_model.get_bias()
        self.assertEqual(weight_keys ^ weight.keys(), set())
        self.assertEqual(bias_keys ^ bias.keys(), set())

    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for an nnqd.LSTM type
            module, the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        for packed_params in mod._all_weight_values:
            self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for the given dynamically
            quantized module type, and that the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        if hasattr(mod, '_all_weight_values'):
            for packed_params in mod._all_weight_values:
                self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
        scripted = torch.jit.script(orig_mod)
        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)

        # Use first calib_data entry as trace input
        traced = torch.jit.trace(orig_mod, calib_data[0])
        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)

    # Call this twice: once for a scripted module and once for a traced module
    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)

        # Test save/load
        buffer = io.BytesIO()
        torch.jit.save(script_mod, buffer)

        buffer.seek(0)
        loaded_mod = torch.jit.load(buffer)
        # Pending __getstate__ and __setstate__ support
        # See tracking task https://github.com/pytorch/pytorch/issues/23984
        if check_save_load:
            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)

    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
        for inp in calib_data:
            ref_output = orig_mod(*inp)
            scripted_output = test_mod(*inp)
            self.assertEqual(scripted_output, ref_output)

    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
                         check=True, eval_mode=True, dynamic=False, qconfig=None):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
        model = get_script_module(module, tracing, inputs[0]).eval()
        if debug:
            print('input graph:', model.graph)
        models = {}
        outputs = {}
        # use a separate loop variable so the `debug` argument above is
        # not shadowed
        for debug_mode in [True, False]:
            if dynamic:
                models[debug_mode] = quantize_dynamic_jit(model, qconfig_dict, debug=debug_mode)
                # make sure it runs
                outputs[debug_mode] = models[debug_mode](inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                inputs_copy = copy.deepcopy(inputs)
                models[debug_mode] = quantize_jit(
                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
                    debug=debug_mode)
                # make sure it runs
                outputs[debug_mode] = models[debug_mode](*inputs[0])

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]

    def checkGraphModuleNodes(
            self, graph_module,
            expected_node=None,
            expected_node_occurrence=None,
            expected_node_list=None):
        """ Check if GraphModule contains the target node
        Args:
            graph_module: the GraphModule instance we want to check
            expected_node, expected_node_occurrence, expected_node_list:
               see docs for checkGraphModeFxOp
        """
        nodes_in_graph = {}
        node_list = []
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        for node in graph_module.graph.nodes:
            n = None
            if node.op == 'call_function' or node.op == 'call_method':
                n = NodeSpec(node.op, node.target)
            elif node.op == 'call_module':
                n = NodeSpec(node.op, type(modules[node.target]))

            if n is not None:
                node_list.append(n)
                if n in nodes_in_graph:
                    nodes_in_graph[n] += 1
                else:
                    nodes_in_graph[n] = 1

        if expected_node is not None:
            self.assertTrue(expected_node in nodes_in_graph, 'node: ' + str(expected_node) +
                            ' not found in the graph module')

        if expected_node_occurrence is not None:
            for expected_node, occurrence in expected_node_occurrence.items():
                if occurrence != 0:
                    self.assertTrue(
                        expected_node in nodes_in_graph,
                        'Check failed for node: ' + str(expected_node) +
                        ' not found')
                    self.assertTrue(
                        nodes_in_graph[expected_node] == occurrence,
                        'Check failed for node: ' + str(expected_node) +
                        ' Expected occurrence: ' + str(occurrence) +
                        ' Found occurrence: ' + str(nodes_in_graph[expected_node]))
                else:
                    self.assertTrue(
                        expected_node not in nodes_in_graph,
                        'Check failed for node: ' + str(expected_node) +
                        ' expected no occurrence but found')

        if expected_node_list is not None:
            cur_index = 0
            for n in node_list:
                if cur_index == len(expected_node_list):
                    return
                if n == expected_node_list[cur_index]:
                    cur_index += 1
            self.assertTrue(
                cur_index == len(expected_node_list),
                "Check failed for graph:\n" +
                self.printGraphModule(graph_module, print_str=False) +
                "\nExpected ordered list: " +
                str(expected_node_list))

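    # Illustrative use of the three check modes above (a sketch):
    #
    #   self.checkGraphModuleNodes(
    #       m,
    #       expected_node=NodeSpec.call_module(nnq.Conv2d),
    #       expected_node_occurrence={NodeSpec.call_method('dequantize'): 1},
    #       expected_node_list=[NodeSpec.call_function(torch.quantize_per_tensor),
    #                           NodeSpec.call_module(nnq.Conv2d)])
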
    def printGraphModule(self, graph_module, print_str=True):
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        node_infos = []
        for n in graph_module.graph.nodes:
            node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs]))
            if n.op == 'call_module':
                node_info += ' module type: ' + repr(type(modules[n.target]))
            node_infos.append(node_info)
        str_to_print = '\n'.join(node_infos)
        if print_str:
            print(str_to_print)
        return str_to_print

    if HAS_FX:
        def assert_types_for_matched_subgraph_pairs(
            self,
            matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
            expected_types: Dict[str, Tuple[Tuple[Callable, Callable], Tuple[Callable, Callable]]],
            gm_a: GraphModule,
            gm_b: GraphModule,
        ) -> None:
            """
            Verifies that the types specified in expected_types match
            the underlying objects pointed to by the nodes in matched_subgraph_pairs.

            An example successful test case:

              matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)}
              expected_types = {'x0': ((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d))}

            The function tests for key equivalence, and verifies types with
            instance checks.
            """

            def _get_underlying_op_type(
                node: Node, gm: GraphModule
            ) -> Union[Callable, str]:
                if node.op == 'call_module':
                    mod = getattr(gm, node.target)
                    return type(mod)
                else:
                    assert node.op in ('call_function', 'call_method')
                    return node.target

            self.assertTrue(
                len(matched_subgraph_pairs) == len(expected_types),
                f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}'
            )
            for k, v in expected_types.items():
                expected_types_a, expected_types_b = v
                exp_type_start_a, exp_type_end_a = expected_types_a
                exp_type_start_b, exp_type_end_b = expected_types_b
                subgraph_a, subgraph_b = matched_subgraph_pairs[k]

                act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a)
                act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b)
                act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a)
                act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b)
                types_match = (exp_type_start_a is act_type_start_a) and \
                    (exp_type_end_a is act_type_end_a) and \
                    (exp_type_start_b is act_type_start_b) and \
                    (exp_type_end_b is act_type_end_b)
                self.assertTrue(
                    types_match,
                    f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, '
                    f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}'
                )

        def assert_ns_compare_dict_valid(
            self,
            act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]],
        ) -> None:
            """
            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
            1. for each layer, results are recorded for two models
            2. number of seen tensors match
            3. shapes of each pair of seen tensors match
            """
            for layer_name, result_type_to_data in act_compare_dict.items():
                for result_type, layer_data in result_type_to_data.items():
                    self.assertTrue(
                        len(layer_data) == 2,
                        f"Layer {layer_name} does not have exactly two model results.")
                    model_name_0, model_name_1 = layer_data.keys()
                    for res_idx in range(len(layer_data[model_name_0])):
                        layer_data_0 = layer_data[model_name_0][res_idx]
                        layer_data_1 = layer_data[model_name_1][res_idx]
                        self.assertTrue(
                            layer_data_0['type'] == layer_data_1['type'],
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.")

                        self.assertTrue(
                            len(layer_data_0['values']) ==
                            len(layer_data_1['values']),
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.")

                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
                        # has rank 4. For now, skip the length check for conv1d only.
                        is_weight_functional_conv1d = (
                            result_type == NSSingleResultValuesType.WEIGHT.value and
                            (
                                'conv1d' in layer_data_0['prev_node_target_type'] or
                                'conv1d' in layer_data_1['prev_node_target_type']
                            )
                        )
                        if not is_weight_functional_conv1d:
                            for idx in range(len(layer_data_0['values'])):
                                values_0 = layer_data_0['values'][idx]
                                values_1 = layer_data_1['values'][idx]
                                if isinstance(values_0, torch.Tensor):
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                elif isinstance(values_0, list):
                                    values_0 = values_0[0]
                                    values_1 = values_1[0]
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                else:
                                    assert isinstance(values_0, tuple), \
                                        f"unhandled type {type(values_0)}"
                                    assert len(values_0) == 2
                                    assert len(values_0[1]) == 2
                                    assert values_0[0].shape == values_1[0].shape
                                    assert values_0[1][0].shape == values_1[1][0].shape
                                    assert values_0[1][1].shape == values_1[1][1].shape

                        # verify that ref_node_name is valid
                        ref_node_name_0 = layer_data_0['ref_node_name']
                        ref_node_name_1 = layer_data_1['ref_node_name']
                        prev_node_name_0 = layer_data_0['prev_node_name']
                        prev_node_name_1 = layer_data_1['prev_node_name']
                        if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value:
                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
                        elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value:
                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
                            self.assertTrue(ref_node_name_1 != prev_node_name_1)

        def checkGraphModeFxOp(
                self,
                model,
                inputs,
                quant_type,
                expected_node=None,
                expected_node_occurrence=None,
                expected_node_list=None,
                is_reference=False,
                print_debug_info=False,
                custom_qconfig_dict=None,
                prepare_expected_node=None,
                prepare_expected_node_occurrence=None,
                prepare_expected_node_list=None,
                prepare_custom_config=None,
                backend_config=None):
            """ Quantizes model with graph mode quantization on fx and checks if the
                quantized model contains the quantized_node

                Args:
                    model: floating point torch.nn.Module
                    inputs: a tuple of sample inputs, passed positionally to the model
                    expected_node: NodeSpec
                        e.g. NodeSpec.call_function(torch.quantize_per_tensor)
                    expected_node_occurrence: a dict from NodeSpec to
                        expected number of occurrences (int)
                        e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                              NodeSpec.call_method('dequantize'): 1}
                    expected_node_list: a list of NodeSpec, used to check the order
                        of the occurrence of Node
                        e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                              NodeSpec.call_module(nnq.Conv2d),
                              NodeSpec.call_function(F.hardtanh_),
                              NodeSpec.call_method('dequantize')]
                    is_reference: if True, enables reference mode
                    print_debug_info: if True, prints debug info
                    custom_qconfig_dict: overrides default qconfig_dict
                    prepare_expected_node: same as expected_node, but for prepare
                    prepare_expected_node_occurrence: same as
                        expected_node_occurrence, but for prepare
                    prepare_expected_node_list: same as expected_node_list, but
                        for prepare

                Returns:
                    A dictionary with the following structure:
                    {
                        "prepared": ...,  # the prepared model
                        "quantized": ...,  # the quantized non-reference model
                        "quantized_reference": ...,  # the quantized reference model
                        "quantized_output": ...,  # the output of the quantized model
                        "quantized_reference_output": ...,  # the output of the
                                                            # quantized reference model
                    }
            """
            # TODO: make img_data a single example instead of a list
            if type(inputs) == list:
                inputs = inputs[0]

            if quant_type == QuantType.QAT:
                qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine)
                model.train()
            elif quant_type == QuantType.STATIC:
                qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine)
                model.eval()
            else:
                qconfig = default_dynamic_qconfig
                qconfig_mapping = QConfigMapping().set_global(qconfig)
                model.eval()

            if quant_type == QuantType.QAT:
                prepare = prepare_qat_fx
            else:
                prepare = prepare_fx

            # overwrite qconfig_dict with custom_qconfig_dict
            if custom_qconfig_dict is not None:
                assert type(custom_qconfig_dict) in (QConfigMapping, dict), \
                    'custom_qconfig_dict should be a QConfigMapping or a dict'
                if isinstance(custom_qconfig_dict, QConfigMapping):
                    qconfig_mapping = custom_qconfig_dict
                else:
                    qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict)
            prepared = prepare(
                model, qconfig_mapping,
                example_inputs=inputs,
                prepare_custom_config=prepare_custom_config,
                backend_config=backend_config)
            if quant_type != QuantType.DYNAMIC:
                prepared(*inputs)

            if print_debug_info:
                print()
                print('quant type:\n', quant_type)
                print('original model:\n', model)
                print()
                print('prepared model:\n', prepared)

            self.checkGraphModuleNodes(
                prepared, prepare_expected_node,
                prepare_expected_node_occurrence, prepare_expected_node_list)

            prepared_copy = copy.deepcopy(prepared)
            qgraph = convert_fx(copy.deepcopy(prepared))
            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
            result = qgraph(*inputs)
            result_reference = qgraph_reference(*inputs)
            qgraph_copy = copy.deepcopy(qgraph)
            qgraph_reference_copy = copy.deepcopy(qgraph_reference)

            qgraph_to_check = qgraph_reference if is_reference else qgraph
            if print_debug_info:
                print()
                print('quantized model:\n', qgraph_to_check)
                self.printGraphModule(qgraph_to_check)
                print()
            self.checkGraphModuleNodes(
                qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
            return {"prepared": prepared_copy,
                    "quantized": qgraph_copy,
                    "quantized_reference": qgraph_reference_copy,
                    "quantized_output": result,
                    "quantized_reference_output": result_reference}

    def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets,
                                    set_qconfig, is_emb_bag, dtype=torch.quint8):
        # Test serialization of dynamic EmbeddingBag module using state_dict
        if is_emb_bag:
            inputs = [indices, offsets]
        else:
            inputs = [indices]

        emb_dict = qemb.state_dict()
        b = io.BytesIO()
        torch.save(emb_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
        # Check unpacked weight values explicitly
        for key in emb_dict:
            if isinstance(emb_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                emb_weight = embedding_unpack(emb_dict[key])
                loaded_weight = embedding_unpack(loaded_dict[key])
                self.assertEqual(emb_weight, loaded_weight)

        # Check state dict serialization and torch.save APIs
        if is_emb_bag:
            loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                           include_last_offset=True, mode='sum', dtype=dtype)
        else:
            loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
        self.check_eager_serialization(qemb, loaded_qemb, inputs)

        loaded_qemb.load_state_dict(loaded_dict)
        self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight),
                         embedding_unpack(loaded_qemb._packed_params._packed_weight))

        # Test JIT serialization
        self.checkScriptable(qemb, [inputs], check_save_load=True)

        # Test from_float call
        if is_emb_bag:
            float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                                    include_last_offset=True, scale_grad_by_freq=False, mode='sum')
        else:
            float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

        if set_qconfig:
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer,
                                              weight=float_qparams_observer)

        prepare_dynamic(float_embedding)

        float_embedding(*inputs)
        if is_emb_bag:
            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
            expected_name = "QuantizedEmbeddingBag"
        else:
            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
            expected_name = "QuantizedEmbedding"

        q_embeddingbag(*inputs)
        self.assertTrue(expected_name in str(q_embeddingbag))



class QuantizationLiteTestCase(QuantizationTestCase):
    def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
        # Creates a quantized model for testing mobile script modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            model = model_class(**kwargs)
            model = quantize(model, test_only_eval_fn, [self.calib_data])
        return model

    def _compare_script_and_mobile(self,
                                   model: torch.nn.Module,
                                   input: torch.Tensor):
        # Compares the numerical outputs of the script and lite modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            script_module = torch.jit.script(model)
            script_module_result = script_module(input)

            max_retry = 5
            for retry in range(1, max_retry + 1):
                # Retry up to `max_retry` times; break on success, re-raise
                # the AssertionError on the final attempt.
                try:
                    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
                    buffer.seek(0)
                    mobile_module = _load_for_lite_interpreter(buffer)
                    mobile_module_result = mobile_module(input)
                    torch.testing.assert_close(script_module_result, mobile_module_result)
                    mobile_module_forward_result = mobile_module.forward(input)
                    torch.testing.assert_close(script_module_result, mobile_module_forward_result)
                    mobile_module_run_method_result = mobile_module.run_method("forward", input)
                    torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
                except AssertionError as e:
                    if retry == max_retry:
                        raise e
                    else:
                        continue
                break
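

# Editor's note (illustrative sketch, not part of the original helpers): the
# script -> lite-interpreter roundtrip exercised above reduces to the pattern
# below. The helper name is hypothetical; the torch.jit and lite-interpreter
# calls are the same ones used in `_compare_script_and_mobile`.
def _example_lite_roundtrip(model: torch.nn.Module, example_input: torch.Tensor):
    scripted = torch.jit.script(model)
    buf = io.BytesIO(scripted._save_to_buffer_for_lite_interpreter())
    buf.seek(0)
    mobile = _load_for_lite_interpreter(buf)
    # The scripted and lite-interpreter modules should agree numerically.
    torch.testing.assert_close(scripted(example_input), mobile(example_input))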


class PT2EQuantizationTestCase(QuantizationTestCase):
    """
    Base QuantizationTestCase for PT2 with some helper methods.
    """

    _MAP_TO_FX_TRACED_OPS = {
        torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    }

    def _test_quantizer(
        self,
        model,
        example_inputs,
        quantizer,
        expected_node_occurrence,
        expected_node_list=None,
        check_against_fx_quant=False,
        fx_qconfig_mapping=None,
        export_with_dynamic_shape=False,
        is_qat=False,
        is_debug_mode=False,
    ):
        # Reset the dynamo cache so each test captures a fresh graph
        torch._dynamo.reset()
        m_eager = model.eval()

        # Program capture
        m = copy.deepcopy(m_eager)
        dynamic_shapes = tuple(
            {0: torch.export.Dim("dim")} if i == 0 else None
            for i in range(len(example_inputs))
        )
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
        )

        if is_qat:
            m = prepare_qat_pt2e(m, quantizer)
        else:
            m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)
        if is_debug_mode:
            print("quantized model", m)

        pt2_quant_output = m(*example_inputs)
        ns = NodeSpec
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        if expected_node_list is None:
            expected_node_list = []
        node_list = [ns.call_function(n) for n in expected_node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )
        if check_against_fx_quant:
            qconfig_mapping = fx_qconfig_mapping
            backend_config = get_executorch_backend_config()
            m_copy = copy.deepcopy(m_eager)
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            m_fx(*example_inputs)
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            m_fx = capture_pre_autograd_graph(
                m_fx,
                example_inputs,
                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
            )
            node_occurrence = {}
            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
                if k in expected_node_occurrence:
                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
            fx_quant_output = m_fx(*example_inputs)
            self.assertEqual(fx_quant_output, pt2_quant_output)

    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
        # Reset the dynamo cache so each test captures a fresh graph
        torch._dynamo.reset()
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )
        if is_qat:
            m = prepare_qat_pt2e(m, quantizer)
        else:
            m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        return m

    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 2),)
        m = M().eval()
        return self._quantize(m, quantizer, example_inputs)
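

# Editor's note (illustrative sketch, not part of the original helpers): the
# PT2E flow used by `_quantize` above is capture -> prepare -> calibrate ->
# convert. A minimal standalone version, reusing the quantizer and pt2e entry
# points already imported by this file; the function name is hypothetical.
def _example_pt2e_flow(model: torch.nn.Module, example_inputs):
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    m = capture_pre_autograd_graph(model.eval(), example_inputs)  # program capture
    m = prepare_pt2e(m, quantizer)  # insert observers
    m(*example_inputs)              # calibrate on representative inputs
    return convert_pt2e(m)          # lower to quantize/dequantize ops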


# Below are a series of toy models to use in testing quantization

class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class AnnotatedSingleLayerLinearModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class SingleLayerLinearDynamicModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearAddModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class RNNDynamicModel(torch.nn.Module):
    def __init__(self, mod_type):
        super().__init__()
        self.qconfig = default_dynamic_qconfig
        if mod_type == 'GRU':
            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
        if mod_type == 'LSTM':
            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x):
        x = self.mod(x)
        return x


class RNNCellDynamicModel(torch.nn.Module):
    def __init__(self, mod_type):
        super().__init__()
        self.qconfig = default_dynamic_qconfig
        if mod_type == 'GRUCell':
            self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
        if mod_type == 'LSTMCell':
            self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
        if mod_type == 'RNNReLU':
            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float)
        if mod_type == 'RNNTanh':
            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float)

    def forward(self, x):
        x = self.mod(x)
        return x


class LSTMwithHiddenDynamicModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x, hid):
        x, hid = self.lstm(x, hid)
        return x, hid


class ConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class ConvTransposeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvTransposeModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class ConvBnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvBnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class ConvBnReLUModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)
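

# Editor's note (illustrative sketch, not part of the original models): the
# intended eager-mode use of AnnotatedConvBnReLUModel is fuse -> prepare ->
# calibrate -> convert, so conv+bn+relu is quantized as a single fused module.
def _example_fuse_prepare_convert():
    model = AnnotatedConvBnReLUModel().eval()
    model.fuse_model()  # fold conv, bn, relu into one fused module
    model = torch.ao.quantization.prepare(model)
    model(*model.get_example_inputs())  # calibration pass
    return torch.ao.quantization.convert(model)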


class TwoLayerConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearModelWithSubmodule(nn.Module):
    def __init__(self):
        super().__init__()
        self.subm = TwoLayerLinearModel()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        x = self.subm(x)
        x = self.fc(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.subm.get_example_inputs()


class AnnotatedTwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class ActivationsTestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        self.quant = torch.ao.quantization.QuantStub()
        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
        self.elu = torch.nn.ELU().to(dtype=torch.float)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.hardswish(x)
        x = self.elu(x)
        x = self.dequant(x)
        return x


class LinearReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearReluLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearReluAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        self.relu = torch.nn.ReLU()
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearBnLeakyReluModel(torch.nn.Module):
    def __init__(self, with_bn=True):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.bn1d = nn.BatchNorm1d(5)
        self.leaky_relu = nn.LeakyReLU(0.01)
        self.with_bn = with_bn

    def forward(self, x):
        x = self.linear(x)
        if self.with_bn:
            x = self.bn1d(x)
        x = self.leaky_relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearTanhModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x)
        x = self.tanh(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class ConvBnAddReluModel(torch.nn.Module):
    def __init__(self,
                 with_bn=True,
                 with_relu=True,
                 left_conv=True,
                 two_conv=True,
                 use_torch_add=True):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, (2, 2))
        self.conv2 = nn.Conv2d(5, 5, (2, 2))
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU()
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.two_conv = two_conv
        self.left_conv = left_conv
        self.use_torch_add = use_torch_add

    def forward(self, x1, x2):
        if self.two_conv:
            if self.use_torch_add:
                if self.with_bn:
                    x = torch.add(self.bn(self.conv(x1)), self.conv2(x1))
                else:
                    x = torch.add(self.conv(x1), self.conv2(x1))
            else:
                if self.with_bn:
                    x = self.bn(self.conv(x1)) + self.conv2(x1)
                else:
                    x = self.conv(x1) + self.conv2(x1)
        else:
            if self.use_torch_add:
                if self.left_conv:
                    if self.with_bn:
                        x = torch.add(self.bn(self.conv(x1)), x2)
                    else:
                        x = torch.add(self.conv(x1), x2)
                else:
                    if self.with_bn:
                        x = torch.add(x2, self.bn(self.conv(x1)))
                    else:
                        x = torch.add(x2, self.conv(x1))
            else:
                if self.left_conv:
                    if self.with_bn:
                        x = self.bn(self.conv(x1)) + x2
                    else:
                        x = self.conv(x1) + x2
                else:
                    if self.with_bn:
                        x = x2 + self.bn(self.conv(x1))
                    else:
                        x = x2 + self.conv(x1)
        if self.with_relu:
            x = self.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2))


# TODO: self.fc should be self.conv
class ConvReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


# TODO: self.fc1/self.fc2 should be self.conv1/self.conv2
class ConvReluConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


# TODO: self.fc1/self.fc2 should be self.conv1/self.conv2
class ConvReluAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        self.relu = torch.nn.ReLU()
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class NormalizationTestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.layer_norm = torch.nn.LayerNorm(8)
        self.group_norm = torch.nn.GroupNorm(2, 8)
        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
        self.instance_norm3d = torch.nn.InstanceNorm3d(8)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.layer_norm(x)
        x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3))
        x = self.instance_norm1d(x)
        x = self.instance_norm2d(x.unsqueeze(-1))
        x = self.instance_norm3d(x.unsqueeze(-1))
        return x


class NestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedNestedModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        if qengine == 'fbgemm':
            self.sub2.fc1.qconfig = default_per_channel_qconfig
        else:
            self.sub2.fc1.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedSubNestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedCustomConfigNestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options),
                                 weight=default_weight_observer)
        self.sub2.fc1.qconfig = custom_qconfig

        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class QuantSubModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.sub2.qconfig = default_qconfig
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.fc3.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class InnerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        return self.relu2(self.fc2(self.relu1(self.fc1(x))))

    def fuse_modules(self):
        fusable_layers = []
        named_children = list(self.named_children())
        for idx, (current_name, layer) in enumerate(named_children):
            if isinstance(layer, torch.nn.Linear):
                if idx >= len(named_children) - 1:
                    break
                if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
                    fusable_layers.append([current_name,
                                           named_children[idx + 1][0]])
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
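

# Editor's note: for InnerModule above, the scan over named_children yields
# fusable_layers == [['fc1', 'relu1'], ['fc2', 'relu2']], so both
# Linear -> ReLU pairs are fused in place.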


class FunctionalLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.rand((5, 5))
        self.bias = torch.zeros(5)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class SingleLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class TwoLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalLinearAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = torch.add(x, 5)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalLinearReluModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = FunctionalLinear()

    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear.get_example_inputs()


class FunctionalLinearReluLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.relu = nn.ReLU()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.rand(3, 3, 3, 3)
        self.bias = torch.rand(3)
        self.stride = (1, 1)
        self.padding = (0, 0)
        self.dilation = (1, 1)
        self.groups = 1

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class SingleLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class TwoLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class FunctionalConvReluModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = FunctionalConv2d()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv.get_example_inputs()


class FunctionalConvReluConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.relu = nn.ReLU()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class SkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self):
        super().__init__()
        self.sub = InnerModule()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.fuse_modules()
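

# Editor's note: SkipQuantModel itself assigns no qconfig; the annotated
# variant below shows the actual skip, i.e. setting `self.fc.qconfig = None`
# after giving the model a default qconfig.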


class AnnotatedSkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.module.fuse_modules()


class QuantStubModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)
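

# Editor's note (illustrative sketch, not part of the original models): with
# manual stubs, only the region between QuantStub and DeQuantStub runs
# quantized after eager-mode prepare/convert. The helper name is hypothetical.
def _example_quant_stub_usage():
    model = QuantStubModel().eval()
    model = torch.ao.quantization.prepare(model)
    model(torch.rand(1, 5))  # calibration pass to collect activation stats
    model = torch.ao.quantization.convert(model)
    return model(torch.rand(1, 5))  # fc now executes as a quantized linear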


class ManualLinearQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)


class ManualDropoutQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.dropout(x)
        return self.dequant(x)


class ManualLinearDynamicQATModel(torch.nn.Module):
    r"""A Module that uses a dynamic QAT qconfig by default.
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig or default_dynamic_qat_qconfig
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


class ManualConvLinearQATModel(torch.nn.Module):
    r"""A module with manually inserted `QuantStub` and `DeQuantStub`
    that contains both linear and conv modules
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = x.view(-1, 64).contiguous()
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)


class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
    r"""Same as ManualConvLinearQATModel but with symmetric quantization.
    Supported only with qnnpack.
    """
    def __init__(self):
        super().__init__(default_symmetric_qnnpack_qat_qconfig)


class ManualEmbeddingBagLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
        self.emb.qconfig = default_embedding_qat_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                per_sample_weights: Optional[torch.Tensor] = None):
        x = self.emb(input, offsets, per_sample_weights)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)


class DeFusedEmbeddingBagLinear(nn.Module):
    r"""A module to simulate QAT embedding bag with a linear layer.
    This module uses a separate embedding and bagging op, similar to the
    de-fused form described in the EmbeddingBag documentation:
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
    """
    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.emb.qconfig = default_embedding_qat_qconfig
        self.bagging_op = torch.sum
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.bagging_op(self.emb(input), dim=1)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)
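

# Editor's note (illustrative check, not part of the original models): on 2-D
# indices, summing per-row embeddings matches EmbeddingBag with mode='sum',
# which is the equivalence DeFusedEmbeddingBagLinear relies on.
def _example_defused_embedding_equivalence():
    emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
    emb_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
    emb_bag.weight = emb.weight  # share the weight Parameter
    indices = torch.randint(0, 10, (4, 3))  # 4 bags of 3 indices each
    torch.testing.assert_close(emb_bag(indices), emb(indices).sum(dim=1))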


class SubModelForFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class SubModelWithoutFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.conv(x))


class ModelForFusion(nn.Module):
    def __init__(self, qconfig):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.sub1 = SubModelForFusion()
        self.sub2 = SubModelWithoutFusion()
        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.qconfig = qconfig
        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
        # don't quantize sub2
        self.sub2.qconfig = None
        self.fc.qconfig = None

    def forward(self, x):
        x = x.squeeze(2)
        x = self.quant(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu4(x)
        x = x.unsqueeze(2)
        y = x.unsqueeze(2)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.sub1(x)
        x = self.dequant(x)
        x = self.sub2(x)
        x = x.reshape(-1, 36).contiguous()
        x = self.fc(x)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.bn2(y)
        y = self.relu3(y)
        y = self.dequant(y)
        return x


class ConvBNReLU(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=False)
        )


class ModelWithSequentialFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1)
        self.relu1 = nn.ReLU(inplace=False)
        layers = []
        for i in range(3):
            layers.append(ConvBNReLU())
        self.features = nn.Sequential(*layers)
        head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
        self.classifier = nn.Sequential(*head)
        self.seq = nn.Sequential()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.features(x)
        x = torch.reshape(x, (-1, 3 * 10 * 10))
        x = self.classifier(x)
        x = self.seq(x)
        x = self.dequant(x)
        return x


class ModelForFusionWithBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dequant(x)
        return x


class ModelForLinearBNFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 10)
        self.bn = nn.BatchNorm1d(10)
        nn.init.uniform_(self.bn.weight)
        nn.init.uniform_(self.bn.bias)

    def forward(self, x):
        return self.bn(self.fc(x))
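

# Editor's note (illustrative, not part of the original models): Linear+BN1d
# fusion folds bn(fc(x)) into a single affine layer. With BN statistics
# (mu, var), eps, and affine parameters (gamma, beta):
#     W' = diag(gamma / sqrt(var + eps)) @ W
#     b' = gamma / sqrt(var + eps) * (b - mu) + beta
# so in eval mode the pair can be replaced by one nn.Linear with (W', b').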


class DummyObserver(torch.nn.Module):
    def calculate_qparams(self):
        return 1.0, 0

    def forward(self, x):
        return x


class ModelForConvTransposeBNFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
        self.bn1 = nn.BatchNorm1d(3)
        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
        self.bn2 = nn.BatchNorm2d(3)
        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
        self.bn3 = nn.BatchNorm3d(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.unsqueeze(2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x.unsqueeze(2)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


class ModelWithFunctionals(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        self.myadd_relu = nnq.FloatFunctional()
        self.mymatmul = nnq.FloatFunctional()
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # self.my_scalar_add = nnq.FloatFunctional()
        # self.my_scalar_mul = nnq.FloatFunctional()

    def forward(self, x):
        y = self.mycat.cat([x, x, x])
        z = self.myadd.add(y, y)
        w = self.myadd_relu.add_relu(z, z)
        u = self.mymatmul.matmul(w, w.T)
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # w = self.my_scalar_add.add_scalar(w, -0.5)
        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
        return u
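

# Editor's note (background, to the best of the editor's knowledge):
# FloatFunctional wraps tensor ops such as cat/add/add_relu/matmul in
# observable modules so eager-mode quantization can attach observers to their
# outputs; convert then swaps in nnq.QFunctional, which dispatches to
# quantized kernels using the collected scale and zero_point.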


class ResNetBase(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.myop = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(inplanes, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.myop.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def fuse_model(self):
        # TODO: remove this check and define two fuse_model functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)


class ModelMultipleOps(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.skip_add.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out


# Model to ensure consistency of fake quant with true quant.
# Average pooling and mean operations are not modelled accurately
# with fake-quant, so this model does not contain those operations.
class ModelMultipleOpsNoAvgPool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.maxpool = nn.MaxPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        skip = self.conv2(x)
        out = self.skip_add.add(out, skip)
        out = self.relu2(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out


class EmbeddingBagModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                         include_last_offset=True, scale_grad_by_freq=False, mode='sum')

    def forward(self, indices, offsets, per_sample_weights):
        return self.emb(indices, offsets, per_sample_weights)


class EmbeddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)


class EmbeddingWithStaticLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(4, 2)
        self.emb.qconfig = float_qparams_weight_only_qconfig
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, indices, offsets, linear_in):
        emb = self.emb(indices, offsets)
        q_x = self.quant(linear_in)
        fc = self.fc(q_x)
        fc = self.dequant(fc)
        features = torch.cat([fc] + [emb], dim=1)
        return features


class DenseTopMLP(nn.Module):
    def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None:
        super().__init__()
        self.dense_mlp = nn.Sequential(
            nn.Linear(dense_dim, dense_out),
        )
        self.top_mlp = nn.Sequential(
            nn.Linear(dense_out + embedding_dim, top_out_in),
            nn.Linear(top_out_in, top_out_out),
        )

    def forward(
        self,
        sparse_feature: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        dense_feature = self.dense_mlp(dense)
        features = torch.cat([dense_feature] + [sparse_feature], dim=1)
        out = self.top_mlp(features)
        return out


# Thin wrapper around nn.EmbeddingBag: tracing inside EmbeddingBag is not
# supported at the moment, so keep the bag as a top-level module.
class EmbBagWrapper(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum')

    def forward(self, indices, offsets):
        return self.emb_bag(indices, offsets)


class SparseNNModel(nn.Module):
    _NUM_EMBEDDINGS = 10
    _EMBEDDING_DIM = 5
    _DENSE_DIM = 4
    _DENSE_OUTPUT = 2
    _TOP_OUT_IN = 2
    _TOP_OUT_OUT = 2
    _TOP_MLP_DIM = 1

    def __init__(self) -> None:
        super().__init__()
        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
        self.dense_top = DenseTopMLP(
            self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN,
            self._TOP_OUT_OUT)

    def forward(
        self,
        sparse_indices: torch.Tensor,
        sparse_offsets: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
        out = self.dense_top(sparse_feature, dense)
        return out


class TestHelperModules:
    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.conv = convs[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class ConvTWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.convt = convts[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.convt(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dThenConv1d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1d = torch.nn.Conv1d(3, 3, 3)
            self.conv2d = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv2d(x)
            x = x.squeeze(0)
            x = self.conv1d(x)
            return x

        def example_inputs(self):
            return (torch.randn(1, 3, 5, 5),)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class Conv2dWithTwoCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x1, x2, x3, x4):
            x1 = self.conv1(x1)
            x2 = self.conv2(x2)
            y = torch.cat([x1, x2], dim=1)
            z = x3 + x4
            w = torch.cat([z, y])
            return w

    class ThreeAdd(torch.nn.Module):
        def forward(self, x1, x2, x3, x4):
            y = x1 + x2
            z = x3 + x4
            w = y + z
            return w

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x

    class AddMulScalar(torch.nn.Module):
        def forward(self, x):
            x = x + 3
            x = x * 3
            x += 3
            x *= 3
            return x

    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
            self.linear = torch.nn.Linear(3, 8, bias=False)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv_bn_relu(x)
            permute_out = torch.permute(x, (0, 2, 3, 1))
            linear_out = self.linear(permute_out)
            return linear_out

    class GroupwiseConv2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)

        def forward(self, x):
            return self.conv(x)

        def example_inputs(self):
            return (torch.randn(2, 4, 10, 10),)

    class LinearReluModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc(x))
            return x