_state.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # mypy: allow-untyped-defs
"""JIT-related state.

This module stores various pieces of Python-global state relating to the JIT.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
  7. import os
  8. import weakref
  9. from typing import Any, Dict, Type
  10. import torch
  11. class EnabledProxy:
  12. """Stores whether the JIT is enabled or not.
  13. This is just a wrapper for a bool, so that we get reference semantics
  14. """
  15. def __init__(self):
  16. self.enabled = self.parse_env(
  17. "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
  18. )
  19. def parse_env(self, name, default, true_message, false_message):
  20. value = os.environ.get(name)
  21. if value is None:
  22. return default
  23. if value.lower() in {"1", "true", "yes"}:
  24. return True
  25. elif value.lower() in {"0", "false", "no"}:
  26. return False
  27. if value == "1v":
  28. print(true_message)
  29. return True
  30. elif value == "0v":
  31. print(false_message)
  32. return False
  33. raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")
  34. def __bool__(self):
  35. return self.enabled
  36. _enabled = EnabledProxy()
  37. def disable():
  38. _enabled.enabled = False
  39. def enable():
  40. _enabled.enabled = True
  41. # The Python CompilationUnit. All functions and modules defined in Python will
  42. # live in here. It's defined in Python because doing in cpp creates static
  43. # destruction order issues.
  44. _python_cu = torch._C.CompilationUnit()
  45. # python class => ScriptClass mapping
  46. _script_classes: Dict[Type[Any], Type[Any]] = {}
  47. _name_to_pyclass: Dict[str, Type[Any]] = {}
  48. def _add_script_class(python_class, script_class):
  49. _script_classes[python_class] = script_class
  50. _name_to_pyclass[script_class.qualified_name()] = python_class
  51. def _get_script_class(python_class):
  52. override = getattr(python_class, "_jit_override_qualname", None)
  53. if override is not None:
  54. python_class = _get_python_class(override)
  55. return _script_classes.get(python_class, None)
  56. def _get_python_class(qualified_name):
  57. return _name_to_pyclass.get(qualified_name, None)
  58. def _clear_class_state():
  59. _script_classes.clear()
  60. _name_to_pyclass.clear()
  61. # Caching: we currently cache compilation of free functions and overloaded functions.
  62. # To cache free functions we hold a weak ref to the function object and
  63. # map to the compiled fn's qualified name.
  64. # To cache overloaded functions we hold a weak ref to the function obj and
  65. # map to all of its overloaded compiled fns.
  66. # In the future we could consider caching more types of objects so that
  67. # aliasing is preserved across separate compilations of the same object.
  68. _jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  69. _jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  70. def _try_get_jit_cached_overloads(key):
  71. qual_names = _jit_function_overload_caching.get(key, None)
  72. if qual_names:
  73. return [_python_cu.find_function(qual_name) for qual_name in qual_names]
  74. else:
  75. return None
  76. def _set_jit_overload_cache(key, compiled_fns):
  77. _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
  78. def _try_get_jit_cached_function(key):
  79. if getattr(key, "__disable_jit_function_caching__", False) is True:
  80. return None
  81. qual_name = _jit_caching_layer.get(key, None)
  82. if qual_name:
  83. return _python_cu.find_function(qual_name)
  84. else:
  85. return None
  86. def _set_jit_function_cache(key, value):
  87. # only free functions currently supported
  88. assert isinstance(value, torch.jit.ScriptFunction)
  89. _jit_caching_layer[key] = value.qualified_name