# mypy: ignore-errors

import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)
        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such a change would "
                                   "not be reflected to c++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **(kwargs or {}))
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")
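

# Illustrative sketch (not part of the original module): how DiagTensorBelow's
# __torch_dispatch__ fallback behaves. No handler is registered in handled_ops, so the
# op is re-run on densified diagonals and results whose nonzeros lie on the diagonal
# are re-wrapped. The helper name `_demo_diag_tensor_below` is hypothetical.
def _demo_diag_tensor_below():
    d = DiagTensorBelow(torch.tensor([1., 2., 3.]))
    out = d + d  # falls back: densify, run the add on plain Tensors, re-wrap
    assert isinstance(out, DiagTensorBelow)
    assert torch.equal(out.diag, torch.tensor([2., 4., 6.]))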


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses those to compute the values
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)
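

# Illustrative sketch (not part of the original module): SparseTensor round-trips
# through from_dense()/sparse_to_dense(), and ops with no entry in _SPECIAL_IMPLS go
# through the dense fallback in __torch_dispatch__. The helper name
# `_demo_sparse_tensor` is hypothetical.
def _demo_sparse_tensor():
    dense = torch.tensor([[0., 1.], [2., 0.]])
    s = SparseTensor.from_dense(dense)
    assert torch.equal(s.sparse_to_dense(), dense)
    out = s + s  # no special impl registered: densify, add, re-wrap via from_dense()
    assert isinstance(out, SparseTensor)
    assert torch.equal(out.sparse_to_dense(), 2 * dense)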


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))
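

# Illustrative sketch (not part of the original module): extra_state records the name
# of the last torch function that produced the instance, and deepcopy preserves the
# copied tensor's existing extra_state. The helper name `_demo_non_wrapper_tensor`
# is hypothetical.
def _demo_non_wrapper_tensor():
    t = NonWrapperTensor(torch.randn(3))
    assert t.extra_state['last_func_called'] is None
    out = t.clone()
    assert out.extra_state['last_func_called'] == 'clone'
    copied = deepcopy(out)  # goes through __deepcopy__ and copies extra_state as-is
    assert copied.extra_state['last_func_called'] == 'clone'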


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=torch.randn
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
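

# Illustrative sketch (not part of the original module): each subclass_db entry's
# create_fn builds an instance of the keyed subclass for a given shape. A 1-D shape is
# used here because DiagTensorBelow requires a 1-D diagonal. The helper name
# `_demo_subclass_db` is hypothetical.
def _demo_subclass_db():
    for cls, info in subclass_db.items():
        t = info.create_fn((3,))
        assert isinstance(t, cls), f"{info.name} did not produce a {cls.__name__}"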