# mypy: allow-untyped-defs
import contextlib
import dis
import functools
import logging
import os.path
import random
import re
import sys
import types
import unittest
from typing import List, Optional, Sequence, Union
from unittest.mock import patch

np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch
from torch import fx
from torch._dynamo.output_graph import OutputGraph

from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
    create_instruction,
    debug_checks,
    is_generator,
    transform_code_object,
)
from .guards import CheckFunctionManager, GuardedCode
from .utils import same

unsupported = eval_frame.unsupported
three = 3

log = logging.getLogger(__name__)


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)


def remove_optimized_module_prefix(name) -> str:
    return re.sub(r"^_orig_mod[.]", "", name)


def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    # if isinstance(loss, torch.Tensor) and loss.item() > 1:
    #     log.warning(
    #         f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
    #     )

    grads = dict()
    params = dict()
    for name, param in model.named_parameters():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        param_copy = param
        grad = param.grad
        # Treat None and zero grad as same
        if param.grad is None:
            grad = torch.zeros_like(param)
        grads[name + ".grad"] = grad
        params[name] = param_copy
    results.append(grads)
    results.append(params)

    buffers = dict()
    for name, buffer in model.named_buffers():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        buffers[name] = buffer
    results.append(buffers)

    for example in example_inputs:
        if isinstance(example, (tuple, list)):
            for inp in example:
                if isinstance(inp, torch.Tensor):
                    results.append(inp.grad)
        else:
            if isinstance(example, torch.Tensor):
                results.append(example.grad)
    return results
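

# Hedged usage sketch (illustration only, not part of this module): `model`
# and `inp` are hypothetical stand-ins. The typical pattern is to gather a
# reference bundle eagerly, repeat the same steps on a compiled copy of the
# model, and compare the two bundles with `same`:
#
#     pred = model(inp)
#     loss = reduce_to_scalar_loss(pred)
#     loss.backward()
#     ref = collect_results(model, pred, loss, [inp])
#     # ... rerun the same steps with torch.compile(model) to get `res`, then:
#     # assert same(ref, res)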


def requires_bwd_pass(out):
    if isinstance(out, torch.Tensor):
        return out.requires_grad
    elif isinstance(out, (list, tuple)):
        return any(requires_bwd_pass(x) for x in out)
    elif out is None:
        return False
    elif isinstance(out, int):
        return False
    raise NotImplementedError(
        "Don't know how to check for backward pass requirement", type(out)
    )


def reduce_to_scalar_loss(out):
    """Reduce the output of a model to get scalar loss"""
    if isinstance(out, torch.Tensor):
        # Mean does not work on integer tensors
        return out.sum() / out.numel()
    elif isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
    elif type(out).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
    ):
        return reduce_to_scalar_loss(out.logits)
    elif type(out).__name__ == "SquashedNormal":
        return out.mean.sum()
    elif isinstance(out, dict):
        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
            out.keys()
        )
    raise NotImplementedError("Don't know how to reduce", type(out))
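

# Illustrative sketch of what reduce_to_scalar_loss does (the tensors here
# are made up for the example): structured outputs collapse to a 0-dim tensor.
#
#     out = (torch.randn(2, 3), torch.randn(2, 3))
#     loss = reduce_to_scalar_loss(out)  # mean of the per-tensor means
#     assert loss.dim() == 0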


def debug_dir() -> str:
    path = os.path.join(os.path.dirname(__file__), "../debug")
    if not os.path.exists(path):
        os.mkdir(path)
    return path


def debug_dump(name, code: types.CodeType, extra="") -> None:
    with open(os.path.join(debug_dir(), name), "w") as fd:
        fd.write(
            f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
        )


def debug_insert_nops(
    frame, cache_size, hooks, _, *, skip: int = 0
) -> Optional[GuardedCode]:
    """used to debug jump updates"""

    def insert_nops(instructions, code_options):
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    if is_generator(frame.f_code):
        return None

    debug_checks(frame.f_code)
    code = transform_code_object(frame.f_code, insert_nops)
    graph = OutputGraph(
        code_options={},
        compiler_fn=None,
        root_tx=None,
        export=False,
        export_constraints=None,
        frame_state={"_id": 0},
        # TODO: shouldn't this be f_locals/f_globals from frame?
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn)


class CompileCounter:
    def __init__(self):
        self.frame_count = 0
        self.op_count = 0

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        return gm.forward

    def clear(self):
        self.frame_count = 0
        self.op_count = 0
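

# Hedged usage sketch: CompileCounter doubles as a dynamo backend so tests
# can assert on how much compilation happened. `f` is a hypothetical function
# made up for this example:
#
#     cnt = CompileCounter()
#
#     def f(x):
#         return x.sin() + x.cos()
#
#     opt_f = torch.compile(f, backend=cnt)
#     opt_f(torch.randn(8))
#     assert cnt.frame_count == 1
#     assert cnt.op_count == 3  # sin, cos, add (assuming no graph breaks)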


class CompileCounterWithBackend:
    def __init__(self, backend):
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        from .backends.registry import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        self.graphs.append(gm)
        return lookup_backend(self.backend)(gm, example_inputs)


# Equivalent to backend="eager", but also records graphs that
# we can assert on
class EagerAndRecordGraphs:
    def __init__(self):
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append(gm)
        return gm.forward
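

# Hedged usage sketch: run eagerly while capturing the FX graphs dynamo
# produces, then assert on their normalized readable form. The lambda is a
# made-up example workload:
#
#     backend = EagerAndRecordGraphs()
#     opt_fn = torch.compile(lambda x: x * 2, backend=backend)
#     opt_fn(torch.randn(4))
#     gm_str = backend.graphs[0].print_readable(print_output=False)
#     print(normalize_gm(gm_str))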


def strip_comment(code) -> str:
    code = str(code)
    return re.sub(r"(?m)^ *#.*\n?", "", code)


def remove_trailing_space(code) -> str:
    return "\n".join([line.rstrip() for line in code.split("\n")])


def normalize_gm(gm_str) -> str:
    # Strip comments, since they contain file paths that can differ from
    # system to system.
    return remove_trailing_space(strip_comment(gm_str))


def standard_test(
    self,
    fn,
    nargs,
    expected_ops=None,
    expected_ops_dynamic=None,
    expected_frame_count=1,
):
    if not config.assume_static_by_default and expected_ops_dynamic is not None:
        expected_ops = expected_ops_dynamic

    actual = CompileCounter()

    args1 = [torch.randn(10, 10) for _ in range(nargs)]
    args2 = [torch.randn(10, 10) for _ in range(nargs)]
    correct1 = fn(*args1)
    correct2 = fn(*args2)
    reset()
    opt_fn = optimize_assert(actual)(fn)
    val1a = opt_fn(*args1)
    val2a = opt_fn(*args2)
    val1b = opt_fn(*args1)
    val2b = opt_fn(*args2)
    reset()
    self.assertTrue(same(val1a, correct1))
    self.assertTrue(same(val1b, correct1))
    self.assertTrue(same(val2a, correct2))
    self.assertTrue(same(val2b, correct2))
    self.assertEqual(actual.frame_count, expected_frame_count)
    if expected_ops is not None:
        self.assertEqual(actual.op_count, expected_ops)


def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
    return gm.forward


def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1):
    if not is_correct:
        return "ERROR"
    if pvalue > pvalue_threshold:
        return f"{speedup:.3f}x SAME"
    return f"{speedup:.3f}x p={pvalue:.2f}"
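

# For illustration, some outputs of format_speedup (derived directly from the
# code above):
#
#     format_speedup(1.25, 0.03)                    -> "1.250x p=0.03"
#     format_speedup(1.02, 0.40)                    -> "1.020x SAME"
#     format_speedup(1.25, 0.03, is_correct=False)  -> "ERROR"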


def rand_strided(
    size: Sequence[int],
    stride: Sequence[int],
    dtype: torch.dtype = torch.float32,
    device: Union[str, torch.device] = "cpu",
    extra_size: int = 0,
):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(size, stride))
        + 1
        + extra_size
    )
    if dtype.is_floating_point:
        if dtype.itemsize == 1:
            # The normal-distribution kernel is not implemented for fp8.
            # Work around that by creating an fp16 tensor and then casting.
            buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
                dtype=dtype
            )
        else:
            buffer = torch.randn(needed_size, dtype=dtype, device=device)
    else:
        buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
    return torch.as_strided(buffer, size, stride)
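

# Hedged sketch: rand_strided builds non-contiguous layouts directly. Here the
# rows of a 3x4 tensor sit 8 elements apart in the backing buffer (shapes and
# strides are made up for the example):
#
#     t = rand_strided((3, 4), (8, 1))
#     assert t.stride() == (8, 1)
#     assert not t.is_contiguous()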


def _make_fn_with_patches(fn, *patches):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            for module, attr, val in patches:
                stack.enter_context(patch.object(module, attr, val))
            return fn(*args, **kwargs)

    return _fn


def make_test_cls_with_patches(
    cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
    DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    DummyTestClass.__qualname__ = DummyTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                setattr(DummyTestClass, name, getattr(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_patches(fn, *patches)
            new_fn.__name__ = new_name
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            setattr(DummyTestClass, new_name, decorator(new_fn))
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(DummyTestClass, name):
            setattr(DummyTestClass, name, getattr(cls, name))

    return DummyTestClass
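

# Hedged usage sketch: clone a TestCase class with a config patch applied
# around every test_* method. `MyTests` is hypothetical; the patched config
# flag mirrors how the dynamic-shapes test files use this helper:
#
#     DynamicShapesMyTests = make_test_cls_with_patches(
#         MyTests,
#         "DynamicShapes",
#         "_dynamic_shapes",
#         (config, "assume_static_by_default", False),
#         xfail_prop="_expected_failure_dynamic",
#     )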


# test Python 3.11+ specific features
def skipIfNotPy311(fn):
    if sys.version_info >= (3, 11):
        return fn
    return unittest.skip("requires Python 3.11+")(fn)


def skipIfNotPy312(fn):
    if sys.version_info >= (3, 12):
        return fn
    return unittest.skip("requires Python 3.12+")(fn)


def xfailIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.expectedFailure(fn)
    return fn


def skipIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.skip("not supported on Python 3.12+")(fn)
    return fn
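

# For example (the test method name is hypothetical):
#
#     @skipIfNotPy311
#     def test_some_311_bytecode_behavior(self):
#         ...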


# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):
    fn._expected_failure_dynamic = True
    return fn


# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
def expectedFailureCodegenDynamic(fn):
    fn._expected_failure_codegen_dynamic = True
    return fn


# Controls test generated in test/inductor/test_cpp_wrapper.py
def expectedFailureDynamicWrapper(fn):
    fn._expected_failure_dynamic_wrapper = True
    return fn


def reset_rng_state(use_xla=False):
    torch.manual_seed(1337)
    random.seed(1337)
    if np:
        np.random.seed(1337)
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(1337, str(xm.xla_device()))