# mypy: allow-untyped-defs
import types
from contextlib import contextmanager

# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change
# later.
__allow_nonbracketed_mutation_flag = True


def disable_global_flags():
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False


def flags_frozen():
    return not __allow_nonbracketed_mutation_flag


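# Illustrative sketch (not part of the original file): once the test harness
# has called disable_global_flags(), every ContextProp below refuses bare
# assignment, and flags_frozen() reports the frozen state:
#
#     disable_global_flags()
#     assert flags_frozen()
#     torch.backends.cudnn.enabled = True  # raises RuntimeError once frozen

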
@contextmanager
def __allow_nonbracketed_mutation():
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        __allow_nonbracketed_mutation_flag = old


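# Illustrative sketch: a backend's flags() context manager can lift the freeze
# just long enough to mutate its flags and restores them on exit;
# torch.backends.cudnn.flags() follows roughly this shape (set_flags here
# stands in for that module's own helper and is only an assumed name):
#
#     @contextmanager
#     def flags(enabled=False):
#         with __allow_nonbracketed_mutation():
#             orig_flags = set_flags(enabled)
#         try:
#             yield
#         finally:
#             with __allow_nonbracketed_mutation():
#                 set_flags(*orig_flags)

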
class ContextProp:
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter()

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
            raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )


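# Illustrative sketch: a backend module wires each flag to a getter/setter
# pair from torch._C; torch.backends.cudnn declares its properties roughly as:
#
#     class CudnnModule(PropModule):
#         enabled = ContextProp(
#             torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled
#         )

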
class PropModule(types.ModuleType):
    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        return self.m.__getattribute__(attr)


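# Illustrative sketch: the PropModule subclass replaces the real module object
# in sys.modules, so flag reads and writes go through the ContextProp
# descriptors while every other attribute is delegated to the original module:
#
#     sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

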
from torch.backends import (
    cpu as cpu,
    cuda as cuda,
    cudnn as cudnn,
    mha as mha,
    mkl as mkl,
    mkldnn as mkldnn,
    mps as mps,
    nnpack as nnpack,
    openmp as openmp,
    quantized as quantized,
)

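# Example usage (sketch): the re-exports above are what make the backend
# modules reachable as attributes of torch.backends:
#
#     import torch
#     print(torch.backends.cudnn.enabled)  # read via ContextProp.__get__
#     with torch.backends.cudnn.flags(enabled=True, benchmark=True):
#         ...  # bracketed mutation; works even after disable_global_flags()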