# mypy: allow-untyped-defs
"""Functionality for Python <-> C++ frontend inter-op."""
from torch import nn
  4. class OrderedDictWrapper:
  5. """A wrapper around a C++ OrderedDict.
  6. It dynamically evaluates the OrderedDict getter on a bound C++ module, such
  7. that new changes on the C++ side are picked up. Otherwise accessing e.g.
  8. ``cpp_module._parameters`` just once would get a frozen copy of the parameters
  9. at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__``
  10. so using properties does not work.
  11. """
  12. def __init__(self, cpp_module, attr):
  13. self.cpp_module = cpp_module
  14. self.attr = attr
  15. @property
  16. def cpp_dict(self):
  17. return getattr(self.cpp_module, self.attr)
  18. # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
  19. # must manually override them.
  20. def items(self):
  21. return self.cpp_dict.items()
  22. def keys(self):
  23. return self.cpp_dict.keys()
  24. def values(self):
  25. return self.cpp_dict.values()
  26. def __iter__(self):
  27. return self.cpp_dict.__iter__()
  28. def __len__(self):
  29. return self.cpp_dict.__len__()
  30. def __contains__(self, key):
  31. return self.cpp_dict.__contains__(key)
  32. def __getitem__(self, key):
  33. return self.cpp_dict.__getitem__(key)
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor (the ``training`` property
        # setter below reads ``self.cpp_module``).
        self.cpp_module = cpp_module
        super().__init__()
        # Replace the dicts ``nn.Module.__init__`` installed with live views
        # into the C++ module, so parameter/buffer/submodule lookups always
        # reflect the C++ side. The static types differ from ``nn.Module``'s
        # declarations, hence the ignores.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        # Mirror the C++ module's public attributes onto this wrapper.
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        # Apply ``fn`` in place to every parameter (and its gradient) and
        # buffer. NOTE(review): ``recurse`` is accepted but not used here;
        # presumably it exists to match ``nn.Module._apply``'s signature.
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean instance attribute; here it is a
    # class-level property that forwards to the C++ module so both sides stay
    # in sync.
    @property  # type: ignore[override]
    def training(self):
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Route mode changes through the C++ module's ``train`` method.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate the string representation to the wrapped C++ module.
        return self.cpp_module.__repr__()