__init__.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import torch
  2. from . import convert_frame, eval_frame, resume_execution
  3. from .backends.registry import list_backends, lookup_backend, register_backend
  4. from .callback import callback_handler, on_compile_end, on_compile_start
  5. from .code_context import code_context
  6. from .convert_frame import replay
  7. from .decorators import (
  8. allow_in_graph,
  9. assume_constant_result,
  10. disable,
  11. disallow_in_graph,
  12. forbid_in_graph,
  13. graph_break,
  14. mark_dynamic,
  15. mark_static,
  16. mark_static_address,
  17. maybe_mark_dynamic,
  18. run,
  19. )
  20. from .eval_frame import (
  21. _reset_guarded_backend_cache,
  22. explain,
  23. export,
  24. is_dynamo_supported,
  25. is_inductor_supported,
  26. optimize,
  27. optimize_assert,
  28. OptimizedModule,
  29. reset_code,
  30. )
  31. from .external_utils import is_compiling
  32. from .mutation_guard import GenerationTracker
  33. from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
# Public names of this package, i.e. what ``from <package> import *`` picks up.
# Every entry is imported above, except ``reset``, which is defined further down
# in this module. Other module-level names (e.g. ``is_dynamo_supported``,
# ``reset_code_caches``) are left out of ``*`` imports but remain importable
# by explicit name.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "forbid_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
]
# Replace ``torch.manual_seed`` with a ``torch._disable_dynamo``-wrapped version.
# The identity check only passes while ``torch.manual_seed`` is still the
# original ``torch.random.manual_seed``. Assuming ``_disable_dynamo`` returns a
# new wrapper object (TODO confirm), the rebinding below makes the check fail on
# any later execution of this module (e.g. a reload), so the function is not
# wrapped twice.
if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = torch._disable_dynamo(torch.manual_seed)
    # Add the new manual_seed to the builtin registry.
    # Registered under the same "aten::manual_seed" name, presumably so that
    # TorchScript still resolves the rebound function -- verify in torch.jit.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
def reset() -> None:
    """Clear all compile caches and restore initial state.

    Everything below runs while holding ``convert_frame.compile_lock``, in
    this order:

    * ``reset_code_caches()`` -- per-code-object caches for each code object
      recorded in ``convert_frame.input_codes.seen`` and
      ``convert_frame.output_codes.seen``;
    * clear the input/output code trackers, ``orig_code_map``,
      ``guard_failures``, ``graph_break_reasons`` and
      ``resume_execution.ContinueExecutionCache.cache``;
    * reset the guarded backend cache and the frame counters
      (``reset_frame_count()``, ``convert_frame.FRAME_COUNTER``,
      ``convert_frame.FRAME_COMPILE_COUNTER``);
    * clear the compiled-autograd cache, the compile callbacks
      (``callback_handler``), ``GenerationTracker`` and the ``warn_once``
      cache in ``torch._dynamo.utils``.

    Returns:
        None.
    """
    with convert_frame.compile_lock:
        # NOTE(review): reset_code_caches() acquires compile_lock again while we
        # already hold it, so this relies on compile_lock being re-entrant
        # (e.g. threading.RLock) -- confirm in convert_frame.
        # Must run before the two clears below: it iterates the trackers'
        # ``.seen`` lists, which would otherwise already be gone.
        reset_code_caches()
        convert_frame.input_codes.clear()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()
        torch._C._dynamo.compiled_autograd.clear_cache()
        # Restart the global frame counter and its per-frame compile counts.
        convert_frame.FRAME_COUNTER = 0
        convert_frame.FRAME_COMPILE_COUNTER.clear()
        callback_handler.clear()
        GenerationTracker.clear()
        torch._dynamo.utils.warn_once_cache.clear()
  83. def reset_code_caches() -> None:
  84. """Clear compile caches that are keyed by code objects"""
  85. with convert_frame.compile_lock:
  86. for weak_code in (
  87. convert_frame.input_codes.seen + convert_frame.output_codes.seen
  88. ):
  89. code = weak_code()
  90. if code:
  91. reset_code(code)
  92. code_context.clear()