test_case.py

# mypy: allow-untyped-defs
import contextlib
import importlib
import logging

import torch
import torch.testing
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


log = logging.getLogger(__name__)


def run_tests(needs=()):
    from torch.testing._internal.common_utils import run_tests

    if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:
        return  # skip testing

    # `needs` names prerequisites for the test file: the string "cuda"
    # requires a CUDA device; any other string must be importable.
    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda" and not torch.cuda.is_available():
            return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()
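
# Usage sketch (not part of the original file): a Dynamo test module hands
# control to run_tests() from its __main__ block so the skip and
# prerequisite checks above apply. For example, a test file that requires
# a CUDA device would end with:
#
#     if __name__ == "__main__":
#         from torch._dynamo.test_case import run_tests
#         run_tests(needs="cuda")
#
# The torch._dynamo.test_case import path is an assumption inferred from
# the relative imports above.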


class TestCase(TorchTestCase):
    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            ),
        )
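
    # Per-test hooks: setUp snapshots the global grad mode and resets
    # Dynamo state; tearDown reports the collected counters, resets again,
    # and restores grad mode if the test flipped it.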
    def setUp(self):
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()
        if self._prior_is_grad_enabled is not torch.is_grad_enabled():
            log.warning("Running test changed grad mode")
            torch.set_grad_enabled(self._prior_is_grad_enabled)
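

# Usage sketch (hypothetical, not part of the original file): a minimal
# test module built on the helpers above. The class and test names are
# invented; torch.compile with the "eager" backend is a standard way to
# exercise Dynamo without a real compiler backend.
class ExampleDynamoTest(TestCase):
    def test_compiled_add(self):
        @torch.compile(backend="eager")
        def fn(x):
            return x + 1

        x = torch.ones(3)
        self.assertEqual(fn(x), x + 1)


if __name__ == "__main__":
    run_tests()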