| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- # mypy: allow-untyped-defs
- """Functionality for Python <-> C++ frontend inter-op."""
- from torch import nn
class OrderedDictWrapper:
    """A wrapper around a C++ OrderedDict.

    It dynamically evaluates the OrderedDict getter on a bound C++ module, such
    that new changes on the C++ side are picked up. Otherwise accessing e.g.
    ``cpp_module._parameters`` just once would get a frozen copy of the parameters
    at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via
    ``self.__dict__``, so using properties does not work.

    Args:
        cpp_module: The object whose attribute holds the dict (the bound C++
            module in practice).
        attr: Name of that attribute, e.g. ``"_parameters"``. It is looked up
            again on every access rather than cached.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return ``getattr(self.cpp_module, self.attr)``, fetched anew on every access."""
        return getattr(self.cpp_module, self.attr)

    # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
    # must manually override them. Every accessor goes through ``cpp_dict`` so
    # that it always sees the current value on the C++ side.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    ``_parameters``, ``_buffers`` and ``_modules`` are replaced by
    ``OrderedDictWrapper`` instances that re-read the same-named attribute of
    ``cpp_module`` on every access. Every public (non-underscore) attribute of
    ``cpp_module`` is copied onto this instance at construction time.

    Args:
        cpp_module: The bound C++ frontend module to wrap.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor (the ``training`` setter
        # below forwards to ``self.cpp_module``).
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and private names; this also skips the three
            # attributes wrapped above.
            if attr.startswith("_"):
                continue
            # NOTE: the value is read once here. Methods keep delegating to the
            # C++ module, but other values are captured as of construction.
            setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to the data of every parameter, gradient and buffer.

        Args:
            fn: Callable applied to each tensor's ``.data``; its result is
                assigned back to ``.data``.
            recurse: Forwarded to ``self.parameters`` / ``self.buffers``. When
                ``False``, only tensors owned directly by this module are
                visited instead of submodules' tensors as well. Note that these
                methods may be the C++ module's own, copied in ``__init__``.

        Returns:
            ``self``, as ``nn.Module._apply`` does.
        """
        for param in self.parameters(recurse=recurse):
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers(recurse=recurse):
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Training flag, read directly from ``cpp_module.training``."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Forward to the C++ module so the mode lives in one place.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate to the wrapped C++ module's representation.
        return repr(self.cpp_module)
|