# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import collections
import contextlib
import functools
import inspect
import math
import operator
import os
import random
import sys
import warnings
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.proxy import ParameterProxy

from .. import logging
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..modeling_utils import PretrainedConfig, PreTrainedModel
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_IMAGE_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from .import_utils import (
    ENV_VARS_TRUE_VALUES,
    TORCH_FX_REQUIRED_VERSION,
    get_torch_version,
    is_peft_available,
    is_torch_fx_available,
)

# PEFT is optional; PeftModel is only imported so PEFT-wrapped models can be recognized.
if is_peft_available():
    from peft import PeftModel

logger = logging.get_logger(__name__)
# Extra tracing diagnostics are enabled when the FX_DEBUG_MODE env var is truthy ("1", "ON", ...).
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
  72. def _generate_supported_model_class_names(
  73. model_name: Type[PretrainedConfig],
  74. supported_tasks: Optional[Union[str, List[str]]] = None,
  75. ) -> List[str]:
  76. task_mapping = {
  77. "default": MODEL_MAPPING_NAMES,
  78. "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
  79. "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
  80. "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
  81. "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  82. "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  83. "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
  84. "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
  85. "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
  86. "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
  87. "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
  88. "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
  89. "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
  90. "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  91. "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  92. "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
  93. "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
  94. "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
  95. "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
  96. "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
  97. }
  98. if supported_tasks is None:
  99. supported_tasks = task_mapping.keys()
  100. if isinstance(supported_tasks, str):
  101. supported_tasks = [supported_tasks]
  102. model_class_names = []
  103. for task in supported_tasks:
  104. class_name = task_mapping[task].get(model_name, None)
  105. if class_name:
  106. model_class_names.append(class_name)
  107. return model_class_names
# Model types (config `model_type` strings) traceable through the task mappings above.
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "dinov2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hiera",
    "hubert",
    "layoutlm",
    "llama",
    "cohere",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mistral",
    "mixtral",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "qwen2",
    "qwen2_moe",
    "resnet",
    "roberta",
    "segformer",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
    # "xlnet",
]

# Model types for which tracing with key/value cache inputs is supported.
_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

# Expand each model type into the concrete model class names for every supported task.
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

# Architectures supported individually rather than through the task mappings.
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "GitVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]

# Deduplicated, sorted tuple of every traceable model class name.
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))

# Tracer currently running a trace, if any; consulted by the cache proxy factories below.
_CURRENT_TRACER = None
  185. def torch_nn_embedding(self, input):
  186. return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
  187. def torch_nn_functional_embedding(
  188. input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
  189. ):
  190. return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype)
  191. def torch_nn_layernorm(self, input):
  192. return input
  193. def torch_nn_groupnorm(self, input):
  194. return input
  195. def torch_nn_linear(self, input):
  196. return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
  197. def torch_relu(x):
  198. return x
  199. def torch_nn_relu(self, x):
  200. return x
  201. def torch_nn_functional_relu(x, inplace=False):
  202. if not inplace:
  203. raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
  204. return x
  205. def torch_where(condition, x, y):
  206. # torch.where returns the broadcasted tensor of condition, x, and y,
  207. # so hack it by using addition
  208. return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
  209. def torch_abs(input, *, out=None):
  210. if out is not None:
  211. raise ValueError("Don't support in-place abs for MetaTensor analysis")
  212. return input
  213. def torch_arange(*args, **kwargs):
  214. n = len(args)
  215. step = 1
  216. if n == 1:
  217. start = 0
  218. end = args[0]
  219. elif n == 2:
  220. start, end = args
  221. else:
  222. start, end, step = args
  223. if isinstance(start, float):
  224. start = int(start)
  225. if isinstance(end, float):
  226. start = int(end)
  227. if isinstance(step, float):
  228. step = int(step)
  229. step = kwargs.get("step", step)
  230. dtype = kwargs.get("dtype")
  231. return torch.empty((end - start) // step, dtype=dtype, device="meta")
  232. def torch_full(*args, **kwargs):
  233. args = list(args)
  234. # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
  235. if len(args) > 1:
  236. args[1] = 1
  237. else:
  238. kwargs["fill_value"] = 1
  239. kwargs_without_device = dict(kwargs)
  240. kwargs_without_device.pop("device", None)
  241. return torch.full(*args, **kwargs_without_device, device="meta")
  242. def torch_cat(tensors, dim=None, axis=None, *, out=None):
  243. if dim is None and axis is None:
  244. dim = 0
  245. if dim is None and axis is not None:
  246. dim = axis
  247. if dim < 0:
  248. dim = tensors[0].dim() + dim
  249. shapes = [t.shape for t in tensors]
  250. shape = list(shapes[0])
  251. concatenated_dim = sum(shape[dim] for shape in shapes)
  252. final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
  253. return torch.empty(final_shape, device="meta")
  254. def torch_stack(tensors, dim=None, axis=None, *, out=None):
  255. if dim is None and axis is None:
  256. dim = 0
  257. if dim is None and axis is not None:
  258. dim = axis
  259. if dim < 0:
  260. dim = tensors[0].dim() + 1 + dim
  261. shape = list(tensors[0].shape)
  262. shape.insert(dim, len(tensors))
  263. return torch.empty(shape, device="meta")
  264. def torch_add(input, other, *, alpha=1, out=None):
  265. if not isinstance(input, torch.Tensor):
  266. return torch.empty_like(other, device="meta")
  267. if not isinstance(other, torch.Tensor):
  268. return torch.empty_like(input, device="meta")
  269. max_length = max(input.dim(), other.dim())
  270. input_shape = list(input.shape) + [1] * (max_length - input.dim())
  271. other_shape = list(other.shape) + [1] * (max_length - other.dim())
  272. shape = []
  273. for i in range(max_length):
  274. shape.append(max(input_shape[i], other_shape[i]))
  275. return torch.empty(shape, device="meta")
  276. def torch_mul(input, other, *, out=None):
  277. return torch_add(input, other, out=out)
  278. def torch_tensor_mul(self, other):
  279. return torch_mul(self, other)
  280. def torch_matmul(input, other, *, out=None):
  281. d1 = input.dim()
  282. d2 = other.dim()
  283. shape = None
  284. if d1 == 1 and d2 == 1:
  285. shape = None
  286. elif d1 == 2 and d2 == 2:
  287. shape = (input.size(0), other.size(1))
  288. elif d1 == 1 and d2 == 2:
  289. shape = (other.size(1),)
  290. elif d1 == 2 and d1 == 1:
  291. shape = (input.size(0),)
  292. else:
  293. max_length = max(input.dim(), other.dim())
  294. shape1 = list(input.shape)
  295. shape2 = list(other.shape)
  296. if d1 == 1:
  297. shape1 = [1] + shape1
  298. if d2 == 1:
  299. shape2.append(1)
  300. shape1 = [-1] * (max_length - d1) + list(input.shape)
  301. shape2 = [-1] * (max_length - d2) + list(other.shape)
  302. shape = []
  303. for i in range(max_length):
  304. shape.append(max(shape1[i], shape2[i]))
  305. shape[-2] = shape1[-2]
  306. shape[-1] = shape2[-1]
  307. if d1 == 1:
  308. shape.pop(-2)
  309. if d2 == 1:
  310. shape.pop(-1)
  311. if shape is None:
  312. return torch.tensor(0.0, device="meta")
  313. return torch.empty(*shape, device="meta")
  314. def torch_bmm(input, mat2, *, out=None):
  315. if out is not None:
  316. raise ValueError("Don't support in-place bmm for MetaTensor analysis")
  317. batch_size, n, m = input.shape
  318. _, _, p = mat2.shape
  319. return torch.empty(batch_size, n, p, device="meta")
  320. def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
  321. if out is not None:
  322. raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
  323. return torch_bmm(batch1, batch2)
  324. def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
  325. return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
  326. def torch_einsum(equation, *operands):
  327. # TODO: infer shape without performing the computation, this might be quite hard.
  328. concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
  329. return torch.einsum(equation, *concrete_operands).to("meta")
  330. def torch_tensor_repeat(self, *sizes):
  331. shape = list(self.shape)
  332. for i, x in enumerate(sizes):
  333. shape[i] *= x
  334. return torch.empty(shape, device="meta")
  335. def torch_repeat_interleave(*args, dim=None, output_size=None):
  336. num_args = len(args)
  337. if num_args == 1:
  338. shape = [output_size if output_size is not None else args[0].sum()]
  339. else:
  340. shape = list(args[0].shape)
  341. if dim is None:
  342. if num_args > 2:
  343. dim = args[2]
  344. else:
  345. shape = [sum(shape)]
  346. dim = 0
  347. repeats = args[1]
  348. if isinstance(repeats, int) or torch.numel(repeats) == 1:
  349. shape[dim] *= int(repeats)
  350. else:
  351. shape[dim] = output_size if output_size is not None else repeats.sum()
  352. return torch.empty(*shape, device="meta")
  353. def torch_index_select(input, dim, index, *, out=None):
  354. shape = list(input.shape)
  355. shape[dim] = len(index)
  356. return torch.empty(*shape, device="meta")
  357. def torch_tensor_index_select(self, dim, index):
  358. return torch_index_select(self, dim, index)
  359. def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
  360. shape = list(input.shape)
  361. shape[dim] = index.shape[dim]
  362. return torch.empty(*shape, device="meta")
  363. def torch_tensor_gather(self, dim, index):
  364. return torch_gather(self, dim, index)
  365. def torch_roll(input, shifts, dims=None):
  366. return input
  367. def torch_flip(input, dims):
  368. return input
  369. def torch_tensor_flip(self, dims):
  370. return self
  371. def torch_nn_conv1d(self, input):
  372. l_in = input.shape[-1]
  373. shape = None
  374. padding = self.padding
  375. if padding == "valid":
  376. padding = (0, 0)
  377. if padding == "same":
  378. shape = list(input.shape)
  379. if shape is None:
  380. shape = list(input.shape)
  381. l_out = math.floor(
  382. (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
  383. )
  384. shape[-1] = l_out
  385. shape[-2] = self.out_channels
  386. return torch.empty(shape, device="meta")
  387. def torch_nn_conv2d(self, input):
  388. h_in, w_in = input.shape[-2:]
  389. shape = None
  390. padding = self.padding
  391. if padding == "valid":
  392. padding = (0, 0)
  393. if padding == "same":
  394. shape = list(input.shape)
  395. if shape is None:
  396. shape = list(input.shape)
  397. h_out = math.floor(
  398. (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
  399. )
  400. w_out = math.floor(
  401. (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
  402. )
  403. shape[-2:] = [h_out, w_out]
  404. shape[-3] = self.out_channels
  405. return torch.empty(shape, device="meta")
  406. def torch_squeeze(input, dim=None):
  407. shape = list(input.shape)
  408. if dim is not None:
  409. if dim < 0:
  410. dim = input.dim() + dim
  411. if shape[dim] == 1:
  412. shape.pop(dim)
  413. else:
  414. new_shape = []
  415. for dim_value in shape:
  416. if dim_value == 1:
  417. continue
  418. new_shape.append(dim_value)
  419. shape = new_shape
  420. return torch.empty(shape, device="meta")
  421. def torch_tensor_squeeze(self, dim=None):
  422. return torch_squeeze(self, dim)
  423. def torch_unsqueeze(input, dim):
  424. shape = list(input.shape)
  425. if dim < 0:
  426. dim = input.dim() + 1 + dim
  427. shape.insert(dim, 1)
  428. return torch.empty(shape, device="meta")
  429. def torch_tensor_unsqueeze(self, dim):
  430. return torch_unsqueeze(self, dim)
  431. def torch_unique_consecutive(input, **kwargs):
  432. output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
  433. if isinstance(output, torch.Tensor):
  434. return output.to("meta")
  435. else:
  436. return tuple(map(output, lambda x: x.to("meta")))
  437. def torch_nn_functional_one_hot(tensor, num_classes=-1):
  438. if num_classes < 0:
  439. raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
  440. shape = list(tensor.shape) + [num_classes]
  441. return torch.empty(shape, device="meta")
  442. def torch_nn_functional_scaled_dot_product_attention(
  443. query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
  444. ):
  445. target_length = query.shape[-2]
  446. head_dim = value.shape[-1]
  447. return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")
  448. def torch_nn_mseloss(self, input, target):
  449. if self.reduction == "none":
  450. shape = target.shape
  451. else:
  452. shape = (1,)
  453. return torch.empty(shape, device="meta")
  454. def torch_nn_crossentropyloss(self, input, target):
  455. if self.reduction == "none":
  456. shape = target.shape
  457. else:
  458. shape = (1,)
  459. return torch.empty(shape, device="meta")
  460. def torch_nn_bcewithlogitsloss(self, input, target):
  461. if self.reduction == "none":
  462. shape = target.shape
  463. else:
  464. shape = (1,)
  465. return torch.empty(shape, device="meta")
  466. def operator_getitem(a, b):
  467. def to_concrete(t):
  468. if isinstance(t, torch.Tensor):
  469. concrete = torch.ones_like(t, device="cpu")
  470. if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
  471. concrete = concrete.to(torch.int64)
  472. return concrete
  473. return t
  474. if isinstance(a, torch.Tensor):
  475. # TODO: infer shape without performing the computation.
  476. if isinstance(b, tuple):
  477. b = tuple(map(to_concrete, b))
  478. else:
  479. b = to_concrete(b)
  480. return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
  481. return operator.getitem(a, b)
# Maps torch callables / nn.Module classes to manual meta-device implementations,
# consulted during meta-tensor propagation instead of executing the real ops.
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.full: torch_full,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.repeat_interleave: torch_repeat_interleave,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.gather: torch_gather,
    torch.Tensor.gather: torch_tensor_gather,
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    operator.getitem: operator_getitem,
}

# `scaled_dot_product_attention` only exists on torch >= 2.0.
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
        torch_nn_functional_scaled_dot_product_attention
    )
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        # Concrete value (typically a meta tensor) computed during propagation;
        # the dunders below consult it to answer data-dependent queries.
        self._metadata = metadata

    @property
    def shape(self):
        # Record `.shape` accesses as a `size()` call on the graph.
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        # Answer from metadata when available, so `len(proxy)` works during tracing.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        # Answer from metadata when available, so data-dependent `if proxy:` works.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()

    def __getattr__(self, k):
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        # Record item assignment as an `operator.setitem` call on the graph.
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        # Answer from metadata when available, so `key in proxy` works during tracing.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)
class HFAttribute(HFProxy):
    """Proxy for an attribute access (`root.attr`) on another proxy.

    Graph-node creation is deferred: most attribute accesses are immediately
    called as methods, in which case no `getattr` node is needed.
    """

    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

        # Mirror the root's metadata so data-dependent dunders keep working.
        if hasattr(self.root, "_metadata"):
            self.install_metadata(getattr(self.root._metadata, attr))

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        # Calling the attribute records a method call on the root proxy instead of
        # materializing a `getattr` node.
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
  582. class MetaDeviceAttribute(HFAttribute):
  583. pass
class HFCacheProxy(HFProxy):
    """
    Proxy that represents an instance of `transformers.cache_utils.Cache`.
    """

    def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
        # Remember which concrete Cache class this proxy stands for.
        self._orig_cache_cls = orig_cache_cls

    @property
    def __class__(self):
        # Impersonate the proxyable class the tracer patched in for the original
        # cache class, so class-based checks on the proxy resolve consistently.
        if not hasattr(self, "_orig_cache_cls"):
            raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
        return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]
  595. def create_wrapper(
  596. function: Callable,
  597. op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
  598. proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
  599. ) -> Callable:
  600. @functools.wraps(function)
  601. def wrapper(*args, **kwargs):
  602. if not is_fx_tracing():
  603. return function(*args, **kwargs)
  604. found_proxies = []
  605. def check_proxy(a):
  606. if isinstance(a, Proxy):
  607. found_proxies.append(a)
  608. torch.fx.node.map_aggregate(args, check_proxy)
  609. torch.fx.node.map_aggregate(kwargs, check_proxy)
  610. if len(found_proxies) > 0:
  611. tracer = found_proxies[0].tracer
  612. if op_type == "call_function":
  613. target = function
  614. elif op_type == "call_method":
  615. target = function.__name__
  616. elif op_type == "get_attr":
  617. target = function.__name__
  618. else:
  619. raise ValueError(f"op_type {op_type} not supported.")
  620. return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
  621. else:
  622. return function(*args, **kwargs)
  623. return wrapper
class HFProxyableClassMeta(type):
    """
    Metaclass that creates a class with its main methods wrapped to be proxyable.
    """

    def __new__(
        cls,
        name: str,
        bases: Tuple[Type, ...],
        attrs: Dict[str, Any],
        proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
    ):
        cls = super().__new__(cls, name, bases, attrs)
        for attr_name in dir(cls):
            attr = getattr(cls, attr_name, None)
            if attr is None:
                continue
            # Decide how each attribute should be recorded on the graph:
            # __init__ is traced as a function call; other dunders are left
            # untouched; methods bound on the class become call_function;
            # plain functions become call_method; everything else is skipped.
            if attr_name == "__init__":
                op_type = "call_function"
            elif attr_name.startswith("__"):
                op_type = None
            elif inspect.ismethod(attr):
                op_type = "call_function"
            elif inspect.isfunction(attr):
                op_type = "call_method"
            else:
                op_type = None
            if op_type is not None:
                setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
        return cls
  653. def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
  654. """
  655. Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
  656. """
  657. wrapper = create_wrapper(target, "call_function")
  658. return wrapper, target
  659. def _proxies_to_metas(v):
  660. """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
  661. if isinstance(v, MetaDeviceAttribute):
  662. return "meta"
  663. if isinstance(v, torch.fx.Proxy):
  664. if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
  665. raise RuntimeError(f"No metadata was found for {v}")
  666. return v._metadata
  667. return v
  668. def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
  669. def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
  670. global _CURRENT_TRACER
  671. if not isinstance(_CURRENT_TRACER, HFTracer):
  672. raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
  673. cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
  674. cache_proxy.install_orig_cache_cls(orig_cache_cls)
  675. return cache_proxy
  676. return cache_proxy_factory_fn
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
# Each is created through HFProxyableClassMeta with a factory that yields
# HFCacheProxy instances tied to the corresponding original cache class.
ProxyableCache = HFProxyableClassMeta(
    "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
    "ProxyableDynamicCache",
    (DynamicCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
    "ProxyableSinkCache",
    (SinkCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
    "ProxyableStaticCache",
    (StaticCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
  699. def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
  700. if forbidden_values is None:
  701. forbidden_values = []
  702. value = random.randint(low, high)
  703. while value in forbidden_values:
  704. value = random.randint(low, high)
  705. return value
class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True
    # Allow stateless modules that are not registered submodules of `root` to be inserted on the fly
    # (see `_insert_module_as_submodule` / `path_of_module`).
    allow_insert_stateless_mods: bool = True
    # torch functions patched during tracing so calls with proxy arguments are recorded as graph nodes
    # instead of being executed eagerly (see `patch_for_tracing`).
    _TORCH_METHODS_TO_PATCH = [
        "arange",
        "zeros",
        "ones",
        "full",
        "full_like",
        "eye",
        "empty",
        "tensor",
        "clamp",
        "finfo",
        "tril",
    ]
    # Cache classes swapped for their proxyable equivalents while tracing (see `patch_for_tracing`).
    _CLASSES_TO_PATCH = {
        Cache: ProxyableCache,
        DynamicCache: ProxyableDynamicCache,
        SinkCache: ProxyableSinkCache,
        StaticCache: ProxyableStaticCache,
    }

    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if not is_torch_fx_available():
            raise ImportError(
                f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
                f"{TORCH_FX_REQUIRED_VERSION} is supported."
            )

    def _generate_dummy_input(
        self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Generates dummy input for model inference recording.

        The kind of tensor generated (dtype, extra dimensions) is chosen from `input_name` and from the
        model class, so that the dummy value is plausible enough for the forward pass to run under tracing.
        """
        # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
        # from pickle, or from the "__class__" attribute in the general case.
        model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
        device = model.device
        inputs_dict = {}

        # when tracing a model with KV cache, we simply need to ensure that the KV cache length is larger than one to
        # rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
        # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
        kv_cache_length = 5

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = shape[0]
            if model_class_name in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
                # Single label per example.
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                "XLNetForQuestionAnswering",
            ]:
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )

                if model.config.problem_type == "regression":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                elif model.config.problem_type == "single_label_classification":
                    labels_shape = (batch_size,)
                    labels_dtype = torch.long
                elif model.config.problem_type == "multi_label_classification":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

            elif model_class_name in [
                *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
                "GPT2DoubleHeadsModel",
                "PeftModelForCausalLM",
                "PeftModelForSeq2SeqLM",
            ]:
                # One label per token.
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
            elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
            else:
                raise NotImplementedError(
                    f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
                )
        elif "pixel_values" in input_name:
            batch_size = shape[0]
            image_size = getattr(model.config, "image_size", None)
            if image_size is None:
                # Fall back to sub-configs, then to an arbitrary random size.
                if hasattr(model.config, "vision_config"):
                    image_size = model.config.vision_config.image_size
                elif hasattr(model.config, "encoder"):
                    image_size = model.config.encoder.image_size
                else:
                    image_size = (_generate_random_int(), _generate_random_int())

            # If no num_channels is in the config, use some arbitrary value.
            num_channels = getattr(model.config, "num_channels", 3)
            if not isinstance(image_size, collections.abc.Iterable):
                image_size = (image_size, image_size)
            height, width = image_size
            inputs_dict[input_name] = torch.zeros(
                batch_size, num_channels, height, width, dtype=torch.float32, device=device
            )
        elif "bbox" in input_name:
            # Bounding boxes: 4 coordinates per box.
            inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
        elif "input_features" in input_name:
            inputs_dict[input_name] = torch.zeros(
                *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
            )
        elif "inputs_embeds" in input_name:
            batch_size = shape[0]

            if (
                getattr(model.config, "embedding_size", None) is not None
                and model.config.model_type != "megatron-bert"
            ):
                embedding_size = model.config.embedding_size
            else:
                embedding_size = model.config.hidden_size

            if len(shape) == 3:
                # (batch_size, num_choices, sequence_length, embedding_size)
                embedding_shape = (batch_size, shape[1], shape[2], embedding_size)
            else:
                # (batch_size, sequence_length, embedding_size)
                embedding_shape = (batch_size, shape[1], embedding_size)

            inputs_dict[input_name] = torch.zeros(embedding_shape, dtype=torch.float, device=device)
        elif "visual_feats" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape
                + [
                    model.config.visual_feat_dim,
                ],
                dtype=torch.float,
                device=device,
            )
        elif "visual_pos" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape
                + [
                    model.config.visual_pos_dim,
                ],
                dtype=torch.float,
                device=device,
            )
        elif "inputs" in input_name:
            inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
        elif "input_values" in input_name:
            batch_size, _ = shape
            # Generating big sequence length for audio inputs.
            seq_length = _generate_random_int(low=10000, high=20000)
            inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
        elif "mask" in input_name:
            if "past_key_values" in input_names:
                # The attention mask must also cover the cached key/value positions.
                mask_shape = [shape[0], shape[1] + kv_cache_length]
            else:
                mask_shape = shape

            inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
        elif "ids" in input_name:
            inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
        elif "past_key_values" in input_name:
            if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
                raise NotImplementedError(
                    f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
                )
            num_heads = model.config.num_attention_heads
            head_dim = model.config.hidden_size // model.config.num_attention_heads

            cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
            # Legacy tuple-of-tuples cache format: one (key, value) pair per layer.
            pkv = tuple(
                (
                    torch.rand(cache_shape, dtype=torch.float, device=device),
                    torch.rand(cache_shape, dtype=torch.float, device=device),
                )
                for i in range(model.config.num_hidden_layers)
            )
            inputs_dict[input_name] = pkv
        else:
            shape_with_hidden_size = shape + [model.config.hidden_size]
            inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)

        return inputs_dict

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        """Create the proxy for a new node and, when possible, execute the op on "meta" values so that the
        resulting metadata can be attached to the proxy (used later for shape/dtype-dependent control flow)."""
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            should_install_metadata = True

            # Temporarily disable attribute proxying / module-call interception: the meta execution
            # below must run on real (meta) values, not create new graph nodes.
            self._disable_module_getattr = True
            self._disable_call_module = True

            if kind == "call_function":
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")
            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                mod = self.root.get_submodule(target)
                mod_type = type(mod)
                if mod_type in _MANUAL_META_OVERRIDES:
                    meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                else:
                    meta_out = self.orig_forward(*args_metas, **kwargs_metas)
            elif kind == "get_attr":
                # Walk the dotted path from root to fetch the actual attribute value.
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)

                if isinstance(attr_itr, torch.Tensor):
                    meta_out = attr_itr.to(device="meta")
                else:
                    meta_out = attr_itr
            else:
                should_install_metadata = False

            if should_install_metadata:
                if not isinstance(rv, Proxy):
                    raise ValueError("Don't support composite output yet")
                rv.install_metadata(meta_out)

        except Exception as e:
            # Metadata is best-effort: tracing proceeds even when meta execution fails.
            if _IS_IN_DEBUG_MODE:
                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        self._disable_module_getattr = False
        self._disable_call_module = False

        return rv

    # Replaced by .getattr from PyTorch 1.13
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        """Return a proxy for parameter/buffer attribute accesses (cached), or the raw value otherwise."""
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:

            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
                # Find `attr_val` by identity in the named collection and proxy it once.
                for n, p in collection_to_search:
                    if attr_val is p:
                        if n not in parameter_proxy_cache:
                            kwargs = {}
                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                                kwargs["proxy_factory_fn"] = (
                                    None
                                    if not self.param_shapes_constant
                                    else lambda node: ParameterProxy(self, node, n, attr_val)
                                )
                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                            parameter_proxy_cache[n] = val_proxy
                        return parameter_proxy_cache[n]
                return None

            if isinstance(attr_val, torch.nn.Parameter):
                maybe_parameter_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_parameters(), parameter_proxy_cache
                )
                if maybe_parameter_proxy is not None:
                    return maybe_parameter_proxy

            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
                maybe_buffer_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_buffers(), parameter_proxy_cache
                )
                if maybe_buffer_proxy is not None:
                    return maybe_buffer_proxy

            return attr_val

    # Needed for PyTorch 1.13+
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """Tracer hook (PyTorch 1.13+ name) delegating to `_module_getattr`."""
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        """Record a module call; keeps `orig_forward` around so meta execution can replay it."""
        if getattr(self, "_disable_call_module", False):
            return forward(*args, **kwargs)
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def proxy(self, node):
        # Use HFProxy (metadata-aware) instead of the stock torch.fx.Proxy.
        return HFProxy(node, self)

    @contextlib.contextmanager
    def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
        """Temporarily patch torch constructors and cache classes for tracing; restores everything on exit."""
        # Patching torch functions
        self.patched_torch_methods = {
            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        # Patching classes
        patched = []
        module_of_model = inspect.getmodule(root)
        for name, mod in sys.modules.items():
            # Only patch the module defining the model, within the transformers namespace.
            if module_of_model is not None and mod is not module_of_model:
                continue
            if not name.startswith("transformers"):
                continue
            for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
                for attr_name, attr in mod.__dict__.items():
                    if attr is orig_cls:
                        patched.append((mod, attr_name, orig_cls))
                        setattr(mod, attr_name, patched_cls)

        yield

        # Restoring patched functions and classes.
        for name, (_, orig) in self.patched_torch_methods.items():
            setattr(torch, name, orig)
        self.patched_torch_methods = {}
        self.orig_fns = set()

        for mod, attr_name, orig_cls in patched:
            setattr(mod, attr_name, orig_cls)

    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        dummy_inputs: Optional[Dict[str, Any]] = None,
        complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
    ) -> Graph:
        """
        Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
        `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
        the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
        `torch.nn.Module` instance to use as the root and add embedded constants to.

        Args:
            root (`torch.nn.Module` or `Callable`):
                Either a `torch.nn.Module`` or a function to be traced through. If root is not a
                [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
            concrete_args (`Dict[str, Any], *optional*):
                Concrete arguments that should not be treated as Proxies
            dummy_inputs (`Dict[str, Any]`, *optional*):
                The dummy inputs needed to handle data-dependent control-flow if `root` is not a
                [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
                [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
            complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
                If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
                `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

        Returns:
            `torch.fx.Graph`:
                A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
        """
        sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

        if concrete_args is None:
            concrete_args = {}

        if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
            for param in sig.parameters.values():
                if param.name in dummy_inputs:
                    continue
                if param.default is inspect.Parameter.empty:
                    raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
            concrete_args.update(
                {
                    p.name: p.default
                    for p in sig.parameters.values()
                    if (p.name not in dummy_inputs and p.name not in concrete_args)
                }
            )

        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
        for input_name in input_names:
            if input_name in inputs:
                continue
            # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
            # be able to use HFTracer._generate_dummy_input.
            if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
                ("_deserialize_graph_module", "_CodeOnlyModule")
            ):
                inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
            else:
                raise RuntimeError(
                    f"Could not generate input named {input_name} for because root is not a"
                    " transformers.PreTrainedModel."
                )

        def to_meta(value):
            # Move tensors to the meta device so tracing never touches real data.
            if isinstance(value, torch.Tensor):
                return value.to("meta")
            return value

        concrete_metas = pytree.tree_map(to_meta, inputs)

        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                concrete_metas[f"**{param.name}"] = {}
        self.meta_args = concrete_metas

        global _CURRENT_TRACER
        _CURRENT_TRACER = self
        with self.patch_for_tracing(root):
            try:
                self.graph = super().trace(root, concrete_args=concrete_args)
            finally:
                _CURRENT_TRACER = None

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    to_visit = [node]
                    to_delete = collections.OrderedDict()
                    while to_visit:
                        n = to_visit.pop(0)
                        to_delete[n] = None
                        to_visit += list(n.users.keys())

                    # Erase users before the nodes they use (reverse insertion order).
                    for user in reversed(to_delete.keys()):
                        self.graph.erase_node(user)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None

        return self.graph

    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
        because its attributes are input-dependent.
        """
        return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())

    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.

        Returns the attribute path under which `mod` was (or already is) registered on `self.root`,
        or an empty string when insertion is not possible.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        already_inserted = False
        while hasattr(self.root, path):
            if getattr(self.root, path) is mod:
                already_inserted = True
                break
            path = f"{mod_name}_{idx}"
            idx += 1

        # No need to add multiple instances of the same module.
        if not already_inserted:
            self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
        a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
        string "foo.bar".

        Args:
            mod (str): The `Module` to retrieve the qualified name for.
        """
        try:
            return super().path_of_module(mod)
        except NameError as e:
            # Fall back to inserting the module on the fly, but only when it carries no state.
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                return path
            raise e

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Modules built from proxies must be traced through, never treated as leaves.
        return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
            m, module_qualified_name
        )

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: "Proxy") -> Any:
        """Called when a proxy object has the keys() method called.
        This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
        your custom tracer.
        """
        attribute = HFAttribute(obj, "keys")()
        if obj.node.target.startswith("**"):
            return attribute._metadata
        return attribute
  1202. def get_concrete_args(model: nn.Module, input_names: List[str]):
  1203. sig = inspect.signature(model.forward)
  1204. if not (set(input_names) <= set(sig.parameters.keys())):
  1205. formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
  1206. formatted_allowed_input_names = ", ".join(sig.parameters.keys())
  1207. raise ValueError(
  1208. f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
  1209. f" {formatted_allowed_input_names}"
  1210. )
  1211. return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
  1212. def is_model_supported(model: "PreTrainedModel"):
  1213. return model.__class__.__name__ in _SUPPORTED_MODELS
  1214. def check_if_model_is_supported(model: "PreTrainedModel"):
  1215. if not is_model_supported(model):
  1216. supported_model_names = ", ".join(_SUPPORTED_MODELS)
  1217. raise NotImplementedError(
  1218. f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
  1219. )
  1220. def symbolic_trace(
  1221. model: "PreTrainedModel",
  1222. input_names: Optional[List[str]] = None,
  1223. disable_check: bool = False,
  1224. tracer_cls: Type[HFTracer] = HFTracer,
  1225. ) -> GraphModule:
  1226. """
  1227. Performs symbolic tracing on the model.
  1228. Args:
  1229. model ([`PretrainedModel`]):
  1230. The model to trace.
  1231. input_names (`List[str]`, *optional*):
  1232. The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
  1233. disable_check (`bool`, *optional*, defaults to `False`):
  1234. If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
  1235. tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
  1236. The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
  1237. Returns:
  1238. `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
  1239. Example:
  1240. ```python
  1241. from transformers.utils.fx import symbolic_trace
  1242. traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
  1243. ```
  1244. """
  1245. if input_names is None:
  1246. input_names = model.dummy_inputs.keys()
  1247. input_names = list(input_names)
  1248. concrete_args = get_concrete_args(model, input_names)
  1249. if not disable_check:
  1250. check_if_model_is_supported(model)
  1251. if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
  1252. logger.warning(
  1253. "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
  1254. "unexpected behavior."
  1255. )
  1256. if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
  1257. logger.warning(
  1258. "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
  1259. "model.config.use_cache = False."
  1260. )
  1261. model.config.use_cache = False
  1262. # Tracing.
  1263. tracer = tracer_cls()
  1264. traced_graph = tracer.trace(model, concrete_args=concrete_args)
  1265. traced = torch.fx.GraphModule(model, traced_graph)
  1266. traced.config = model.config
  1267. # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
  1268. # _generate_dummy_input, where the model class is needed.
  1269. traced.class_for_deserialization = model.__class__
  1270. traced.device = model.device
  1271. return traced