test_minifier_common.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import io
  4. import logging
  5. import os
  6. import re
  7. import shutil
  8. import subprocess
  9. import sys
  10. import tempfile
  11. import traceback
  12. from typing import Optional
  13. from unittest.mock import patch
  14. import torch
  15. import torch._dynamo
  16. import torch._dynamo.test_case
  17. from torch.utils._traceback import report_compile_source_on_error
  18. @dataclasses.dataclass
  19. class MinifierTestResult:
  20. minifier_code: str
  21. repro_code: str
  22. def _get_module(self, t):
  23. match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t)
  24. assert match is not None, "failed to find module"
  25. r = match.group(0)
  26. r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE)
  27. r = re.sub(r"\n{3,}", "\n\n", r)
  28. return r.strip()
  29. def minifier_module(self):
  30. return self._get_module(self.minifier_code)
  31. def repro_module(self):
  32. return self._get_module(self.repro_code)
  33. class MinifierTestBase(torch._dynamo.test_case.TestCase):
  34. DEBUG_DIR = tempfile.mkdtemp()
  35. @classmethod
  36. def setUpClass(cls):
  37. super().setUpClass()
  38. cls._exit_stack.enter_context( # type: ignore[attr-defined]
  39. torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR)
  40. )
  41. # These configurations make new process startup slower. Disable them
  42. # for the minification tests to speed them up.
  43. cls._exit_stack.enter_context( # type: ignore[attr-defined]
  44. torch._inductor.config.patch(
  45. {
  46. # https://github.com/pytorch/pytorch/issues/100376
  47. "pattern_matcher": False,
  48. # multiprocess compilation takes a long time to warmup
  49. "compile_threads": 1,
  50. # https://github.com/pytorch/pytorch/issues/100378
  51. "cpp.vec_isa_ok": False,
  52. }
  53. )
  54. )
  55. @classmethod
  56. def tearDownClass(cls):
  57. if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1":
  58. shutil.rmtree(cls.DEBUG_DIR)
  59. else:
  60. print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
  61. cls._exit_stack.close() # type: ignore[attr-defined]
  62. def _gen_codegen_fn_patch_code(self, device, bug_type):
  63. assert bug_type in ("compile_error", "runtime_error", "accuracy")
  64. return f"""\
  65. {torch._dynamo.config.codegen_config()}
  66. {torch._inductor.config.codegen_config()}
  67. torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
  68. """
  69. def _maybe_subprocess_run(self, args, *, isolate, cwd=None):
  70. if not isolate:
  71. assert len(args) >= 2, args
  72. assert args[0] == "python3", args
  73. if args[1] == "-c":
  74. assert len(args) == 3, args
  75. code = args[2]
  76. args = ["-c"]
  77. else:
  78. assert len(args) >= 2, args
  79. with open(args[1]) as f:
  80. code = f.read()
  81. args = args[1:]
  82. # WARNING: This is not a perfect simulation of running
  83. # the program out of tree. We only interpose on things we KNOW we
  84. # need to handle for tests. If you need more stuff, you will
  85. # need to augment this appropriately.
  86. # NB: Can't use save_config because that will omit some fields,
  87. # but we must save and reset ALL fields
  88. dynamo_config = torch._dynamo.config.shallow_copy_dict()
  89. inductor_config = torch._inductor.config.shallow_copy_dict()
  90. try:
  91. stderr = io.StringIO()
  92. log_handler = logging.StreamHandler(stderr)
  93. log = logging.getLogger("torch._dynamo")
  94. log.addHandler(log_handler)
  95. try:
  96. prev_cwd = os.getcwd()
  97. if cwd is not None:
  98. os.chdir(cwd)
  99. with patch("sys.argv", args), report_compile_source_on_error():
  100. exec(code, {"__name__": "__main__", "__compile_source__": code})
  101. rc = 0
  102. except Exception:
  103. rc = 1
  104. traceback.print_exc(file=stderr)
  105. finally:
  106. log.removeHandler(log_handler)
  107. if cwd is not None:
  108. os.chdir(prev_cwd) # type: ignore[possibly-undefined]
  109. # Make sure we don't leave buggy compiled frames lying
  110. # around
  111. torch._dynamo.reset()
  112. finally:
  113. torch._dynamo.config.load_config(dynamo_config)
  114. torch._inductor.config.load_config(inductor_config)
  115. # TODO: return a more appropriate data structure here
  116. return subprocess.CompletedProcess(
  117. args,
  118. rc,
  119. b"",
  120. stderr.getvalue().encode("utf-8"),
  121. )
  122. else:
  123. return subprocess.run(args, capture_output=True, cwd=cwd, check=False)
  124. # Run `code` in a separate python process.
  125. # Returns the completed process state and the directory containing the
  126. # minifier launcher script, if `code` outputted it.
  127. def _run_test_code(self, code, *, isolate):
  128. proc = self._maybe_subprocess_run(
  129. ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR
  130. )
  131. print("test stdout:", proc.stdout.decode("utf-8"))
  132. print("test stderr:", proc.stderr.decode("utf-8"))
  133. repro_dir_match = re.search(
  134. r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8")
  135. )
  136. if repro_dir_match is not None:
  137. return proc, repro_dir_match.group(1)
  138. return proc, None
  139. # Runs the minifier launcher script in `repro_dir`
  140. def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()):
  141. self.assertIsNotNone(repro_dir)
  142. launch_file = os.path.join(repro_dir, "minifier_launcher.py")
  143. with open(launch_file) as f:
  144. launch_code = f.read()
  145. self.assertTrue(os.path.exists(launch_file))
  146. args = ["python3", launch_file, "minify", *minifier_args]
  147. if not isolate:
  148. args.append("--no-isolate")
  149. launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir)
  150. print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
  151. stderr = launch_proc.stderr.decode("utf-8")
  152. print("minifier stderr:", stderr)
  153. self.assertNotIn("Input graph did not fail the tester", stderr)
  154. return launch_proc, launch_code
  155. # Runs the repro script in `repro_dir`
  156. def _run_repro(self, repro_dir, *, isolate=True):
  157. self.assertIsNotNone(repro_dir)
  158. repro_file = os.path.join(repro_dir, "repro.py")
  159. with open(repro_file) as f:
  160. repro_code = f.read()
  161. self.assertTrue(os.path.exists(repro_file))
  162. repro_proc = self._maybe_subprocess_run(
  163. ["python3", repro_file], isolate=isolate, cwd=repro_dir
  164. )
  165. print("repro stdout:", repro_proc.stdout.decode("utf-8"))
  166. print("repro stderr:", repro_proc.stderr.decode("utf-8"))
  167. return repro_proc, repro_code
  168. # Template for testing code.
  169. # `run_code` is the code to run for the test case.
  170. # `patch_code` is the code to be patched in every generated file; usually
  171. # just use this to turn on bugs via the config
  172. def _gen_test_code(self, run_code, repro_after, repro_level):
  173. return f"""\
  174. import torch
  175. import torch._dynamo
  176. {torch._dynamo.config.codegen_config()}
  177. {torch._inductor.config.codegen_config()}
  178. torch._dynamo.config.repro_after = "{repro_after}"
  179. torch._dynamo.config.repro_level = {repro_level}
  180. torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
  181. {run_code}
  182. """
  183. # Runs a full minifier test.
  184. # Minifier tests generally consist of 3 stages:
  185. # 1. Run the problematic code
  186. # 2. Run the generated minifier launcher script
  187. # 3. Run the generated repro script
  188. #
  189. # If possible, you should run the test with isolate=False; use
  190. # isolate=True only if the bug you're testing would otherwise
  191. # crash the process
  192. def _run_full_test(
  193. self, run_code, repro_after, expected_error, *, isolate, minifier_args=()
  194. ) -> Optional[MinifierTestResult]:
  195. if isolate:
  196. repro_level = 3
  197. elif expected_error is None or expected_error == "AccuracyError":
  198. repro_level = 4
  199. else:
  200. repro_level = 2
  201. test_code = self._gen_test_code(run_code, repro_after, repro_level)
  202. print("running test", file=sys.stderr)
  203. test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate)
  204. if expected_error is None:
  205. # Just check that there was no error
  206. self.assertEqual(test_proc.returncode, 0)
  207. self.assertIsNone(repro_dir)
  208. return None
  209. # NB: Intentionally do not test return code; we only care about
  210. # actually generating the repro, we don't have to crash
  211. self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
  212. self.assertIsNotNone(repro_dir)
  213. print("running minifier", file=sys.stderr)
  214. minifier_proc, minifier_code = self._run_minifier_launcher(
  215. repro_dir, isolate=isolate, minifier_args=minifier_args
  216. )
  217. print("running repro", file=sys.stderr)
  218. repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
  219. self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
  220. self.assertNotEqual(repro_proc.returncode, 0)
  221. return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)