test_case.py 995 B

123456789101112131415161718192021222324252627282930313233343536
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import os
  4. from torch._dynamo.test_case import (
  5. run_tests as dynamo_run_tests,
  6. TestCase as DynamoTestCase,
  7. )
  8. from torch._inductor import config
  9. from torch._inductor.utils import fresh_inductor_cache
  10. def run_tests(needs=()):
  11. dynamo_run_tests(needs)
  12. class TestCase(DynamoTestCase):
  13. """
  14. A base TestCase for inductor tests. Enables FX graph caching and isolates
  15. the cache directory for each test.
  16. """
  17. def setUp(self):
  18. super().setUp()
  19. self._inductor_test_stack = contextlib.ExitStack()
  20. self._inductor_test_stack.enter_context(config.patch({"fx_graph_cache": True}))
  21. if (
  22. os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
  23. and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
  24. ):
  25. self._inductor_test_stack.enter_context(fresh_inductor_cache())
  26. def tearDown(self):
  27. super().tearDown()
  28. self._inductor_test_stack.close()