# debug_utils.py
# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"
import copy
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar

import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir

log = logging.getLogger(__name__)

T = TypeVar("T")
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    import libfb.py.build_info

extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
    par_style = "xar",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split
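
# A hedged usage sketch for BuckTargetWriter (only meaningful inside an fbcode
# checkout; the path below is illustrative, not part of this module):
#
#   writer = BuckTargetWriter("/data/users/me/fbcode/scripts/me/repro.py")
#   cmd = writer.write()  # writes a TARGETS file next to repro.py
#   # cmd == ["buck2", "run", "@mode/dev-nosan", "//scripts/me:repro"]
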
def minifier_dir():
    path = os.path.join(get_debug_dir(), "minifier")
    if path is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


MAX_CONSTANT_NUMEL_INLINE = 4
class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
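
# A hedged sketch of what NNModuleToString.convert produces for a small traced
# module (the module and the emitted text below are illustrative):
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.lin = torch.nn.Linear(3, 3)
#
#       def forward(self, x):
#           return self.lin(x).relu()
#
#   gm = torch.fx.symbolic_trace(M())
#   print(NNModuleToString.convert(gm))
#   # from torch.nn import *
#   # class Repro(torch.nn.Module):
#   #     def __init__(self):
#   #         super().__init__()
#   #         self.lin = Linear(in_features=3, out_features=3, bias=True)
#   #     <gm.code, indented as the forward method>
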
@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.check_output(["nvcc", "--version"])
        cuda_version_lines = cuda_version_out.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except (FileNotFoundError, subprocess.CalledProcessError):
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str
def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    experimental_config = torch.fx.experimental._config.codegen_config()  # type: ignore[attr-defined]
    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""
def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        log.exception("")
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
class AccuracyError(Exception):
    pass


def clone_inputs_retaining_gradness(example_inputs):
    """
    Unlike utils.clone_input, this helper is meant for the minifier, where all
    tensors are leaf tensors of the newly created graph. We therefore set the
    requires_grad field without first checking the leafness of the tensor.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs
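
# A minimal illustration (hypothetical values) of the gradness retention:
#
#   inps = [torch.randn(2, requires_grad=True), torch.randn(2)]
#   cloned = clone_inputs_retaining_gradness(inps)
#   assert cloned[0].requires_grad and not cloned[1].requires_grad
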
def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we use args as-is without cloning. This is
    higher fidelity, but we may destroy the args in the process.
    """
    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # The callable returned by TorchInductor expects a list of inputs, so it
    # may need the boxed calling convention.
    out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)

    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
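
# A hedged usage sketch (gm and inputs are illustrative):
#
#   gm = torch.fx.symbolic_trace(MyModule())
#   res = run_fwd_maybe_bwd(gm, [torch.randn(4, 4, requires_grad=True)])
#   # runs forward, reduces the output to a scalar loss, calls backward,
#   # and returns the collected results (outputs plus gradients)
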
def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check that two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to calculate the
        fp64 reference.
    ignore_non_fp: if True, do not compare outputs which are not floating
        point. This is mostly useful for the minifier (which wants to avoid
        quantizing floating point error into integer/boolean error).
    """
    from .utils import same

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError("Could not generate fp64 outputs")  # noqa: B904
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue. "
            "Skipping this graph."
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing
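
# A hedged usage sketch comparing an eager graph against a compiled version
# (names are illustrative):
#
#   compiled_gm = torch.compile(gm)
#   if not same_two_models(gm, compiled_gm, example_inputs, only_fwd=True):
#       raise AccuracyError("compiled output diverged from eager")
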
def cast_dtype_args_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
        if node.op == "call_function":
            dtype = node.kwargs.get("dtype")
            if dtype is not None and is_float_dtype(dtype):
                new_kwargs = dict(node.kwargs)
                new_kwargs["dtype"] = torch.float64
                node.kwargs = new_kwargs
    model.graph.lint()
    model.recompile()
    return model
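
# Sketch of the rewrite performed above: a graph node such as
#
#   convert_element_type = torch.ops.prims.convert_element_type.default(arg0, torch.float16)
#
# is rewritten in place to
#
#   convert_element_type = torch.ops.prims.convert_element_type.default(arg0, torch.float64)
#
# and any call_function node with a floating-point dtype= kwarg gets
# dtype=torch.float64, so the fp64 reference run is not silently truncated
# by dtypes baked into the graph.
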
def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # replace dtype arguments embedded in the graph with fp64
        model = cast_dtype_args_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)
def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue. "
            "Skipping this graph."
        )
        return False
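
# A hedged usage sketch (compiler_fn is whatever backend is being minified;
# the inductor entry point below is one plausible choice):
#
#   def compiler_fn(gm, example_inputs):
#       return torch._inductor.compile_fx.compile_fx(gm, example_inputs)
#
#   if backend_accuracy_fails(gm, example_inputs, compiler_fn):
#       ...  # keep minifying: this graph still reproduces the accuracy failure
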
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing what the default values of tensor
# values should be.  These all coincide with factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional["torch._prims_common.StrideType"],
    *,
    shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
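
# A quick illustration of the defaulters above (hedged sketch; the values
# simply mirror the factory-function defaults):
#
#   _dtype_or_default(None)                  # -> torch.float32
#   _dtype_or_default(torch.int64)           # -> torch.int64
#   _stride_or_default(None, shape=(2, 3))   # -> (3, 1), contiguous strides
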
class NopInputReader:
    # Counts the storages a repro would load without loading them, e.g. to
    # size a progress bar before a real InputReader pass.
    def __init__(self):
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead.  It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device?  But failing this
                    # way would be very mysterious!  Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC
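
# A hedged sketch of how a generated repro script drives InputReader (the
# load_args body below is illustrative of the emitted format, with "<hash>"
# standing in for a real content-store hash):
#
#   def load_args(reader):
#       buf0 = reader.storage("<hash>", 128)
#       reader.tensor(buf0, (4, 8), requires_grad=True)  # arg0
#       reader.symint(7)  # arg1
#   load_args._version = 0
#
#   reader = InputReader(save_dir=None)  # None => random data instead of real inputs
#   load_args(reader)
#   args = reader.args
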
# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray.  This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it
class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {l}" for l in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r
    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v
    def tensor(self, name, t) -> None:
        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if _storage_offset_or_default(None) != t.storage_offset():
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f")  # {name}"
        )

    # TODO: this doesn't actually symint atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r})  # {name}")
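
# A hedged usage sketch for InputWriter (variable names and save_dir are
# illustrative; see the InputReader sketch above for the consuming side):
#
#   writer = InputWriter(save_dir="/tmp/repro_inputs")
#   writer.tensor("arg0", torch.randn(4, 8, requires_grad=True))
#   writer.symint("arg1", 7)
#   print("\n".join(writer.lines()))
#   # emits a load_args(reader) function whose body replays the calls above
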
def aot_graph_input_parser(
    func: Callable[[List[Tensor]], List[Tensor]],
    device: str = "cuda",
    sym_shapes: Optional[Dict[str, int]] = None,
    default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Takes in a function which has been printed with print_readable() and
    constructs kwargs to run it.

    Handles Tensor inputs, SymInts, and a graph module which might have
    tensor constants.

    Consider a function `forward` defined as follows:

    def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
        _tensor_constant0: "i64[4190]" = self._tensor_constant0
        # Further implementation

    kwargs = aot_graph_input_parser(forward)
    forward(**kwargs)
    """
    from torch.fx.graph import dtype_abbrs

    dtype_map = {value: key for key, value in dtype_abbrs.items()}
    dtype_pattern = "|".join(dtype_abbrs.values())

    # Extracting the source code from the function
    source = inspect.getsource(func)

    # Regular expressions
    tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
    tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
    sym_shape_regex = r"Sym\((s\d+)\)"

    class TensorContainer:
        "Container for tensors as attributes"

        pass

    # Dictionary for tensors from annotations
    kwargs: Dict[str, Any] = {}

    sym_shapes = sym_shapes or {}

    def get_sym_int(symint):
        torch._check(
            symint in sym_shapes or default_sym_shape is not None,
            lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
        )
        return sym_shapes.get(symint, default_sym_shape)

    def gen_tensor(shape, dtype) -> Tensor:
        # Resolve symbolic shapes to concrete values
        resolved_shape = []
        dynamic_dims = []
        for i, dim in enumerate(shape):
            dim = dim.strip()
            if "s" in dim:
                s = get_sym_int(dim)
                resolved_shape.append(s)
                dynamic_dims.append(i)
            else:
                resolved_shape.append(int(dim))
        constructor = torch.randn if dtype.is_floating_point else torch.zeros
        out = constructor(resolved_shape, dtype=dtype, device=device)  # type: ignore[call-arg]
        for d in dynamic_dims:
            torch._dynamo.mark_dynamic(out, d)
        return out

    # Parse function annotations for tensor generation
    annotations = func.__annotations__
    for param, annotation in annotations.items():
        # Skip 'return' annotation
        if param == "return":
            continue

        match = re.search(tensor_regex, annotation)
        if match:
            data_type, shape_str = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            kwargs[param] = gen_tensor(shape, dtype)

        match = re.search(sym_shape_regex, annotation)
        if match:
            kwargs[param] = get_sym_int(match.group(1))

    if "self" in inspect.signature(func).parameters:
        container = TensorContainer()
        kwargs["self"] = container
        for match in re.finditer(tensor_assignment_regex, source):
            attr_name, data_type, shape_str, _ = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            setattr(container, attr_name, gen_tensor(shape, dtype))

    return kwargs
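
# A hedged usage sketch (the `forward` here is whatever graph you pasted from
# print_readable(); the sym_shapes mapping below is illustrative):
#
#   kwargs = aot_graph_input_parser(forward, device="cuda", sym_shapes={"s0": 8})
#   out = forward(**kwargs)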