_classes.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # mypy: allow-untyped-defs
  2. import types
  3. import torch._C
  4. class _ClassNamespace(types.ModuleType):
  5. def __init__(self, name):
  6. super().__init__("torch.classes" + name)
  7. self.name = name
  8. def __getattr__(self, attr):
  9. proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
  10. if proxy is None:
  11. raise RuntimeError(f"Class {self.name}.{attr} not registered!")
  12. return proxy
  13. class _Classes(types.ModuleType):
  14. __file__ = "_classes.py"
  15. def __init__(self):
  16. super().__init__("torch.classes")
  17. def __getattr__(self, name):
  18. namespace = _ClassNamespace(name)
  19. setattr(self, name, namespace)
  20. return namespace
  21. @property
  22. def loaded_libraries(self):
  23. return torch.ops.loaded_libraries
  24. def load_library(self, path):
  25. """
  26. Loads a shared library from the given path into the current process.
  27. The library being loaded may run global initialization code to register
  28. custom classes with the PyTorch JIT runtime. This allows dynamically
  29. loading custom classes. For this, you should compile your class
  30. and the static registration code into a shared library object, and then
  31. call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
  32. shared object.
  33. After the library is loaded, it is added to the
  34. ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
  35. for the paths of all libraries loaded using this function.
  36. Args:
  37. path (str): A path to a shared library to load.
  38. """
  39. torch.ops.load_library(path)
# The classes "namespace": the module-level ``_Classes`` instance. Attribute
# access on it yields a ``_ClassNamespace`` for the requested name.
classes = _Classes()