# __init__.py

import dis
import inspect
from typing import Sequence, Union

import functorch._C
import torch
from functorch._C import dim as _C

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type

_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
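
# A minimal usage sketch (illustrative only; `i`, `j`, and `x` are hypothetical
# names, not part of this module):
#
#   i, j = dims(2)             # create two first-class Dim objects
#   x = torch.randn(3, 4)
#   x_ij = x[i, j]             # bind positional dims 0 and 1 to i and j
#   s = (x_ij * x_ij).sum(i)   # reductions accept Dim objects directly
#   y = x_ij.order(i, j)       # convert back to an ordinary positional tensor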


class DimensionMismatchError(Exception):
    pass


class DimensionBindError(Exception):
    pass


from . import op_properties

# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)

use_c = True
if not use_c:
    from . import reference


class _Tensor:
    # fast path around slow wrapping/unwrapping logic for simple queries used
    # by the implementation...
    @property
    def dims(self):
        return tuple(d for d in self._levels if isinstance(d, Dim))

    def dim(self):
        return self.ndim

    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    index = _C._instancemethod(_C.index)

    def __repr__(self):
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"


TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    # note that _C.Dim comes before _Tensor because we want the Dim API for things like size to take precedence.
    # Tensor defines format, but we want to print Dims with special formatting
    __format__ = object.__format__


class Tensor(_Tensor, _C.Tensor):
    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)

    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)
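
# Roughly torch.cat along `dim`, phrased with first-class dims: the tensors
# are stacked along a fresh dim `n`, then `n` and `dim` are flattened together
# into `new_dim`. Illustrative sketch (hypothetical names):
#   i, j, k = dims(3)
#   parts = [torch.randn(3, 4)[i, j] for _ in range(2)]
#   joined = cat(parts, j, k)   # dims of `joined`: (i, k), with k.size == 2 * j.size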


if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        orig = getattr(torch.Tensor, name)
        setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split

# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)

# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__

_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__

torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
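
# wrap_type presumably copies the remaining torch.Tensor API onto _Tensor,
# dispatching through _Tensor.__torch_function__ so results stay dim-aware
# (description inferred from the name and arguments; not verified here).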
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim

if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional
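
# Each _def call below re-binds the corresponding torch.Tensor method on
# _Tensor so that its `dim` argument also accepts first-class Dim objects.
# The flags describe the wrapped signature (interpretation inferred from the
# flag names and call sites, not verified against _C._wrap):
#   dim_offset     - positional offset of `dim` when it is not the first argument
#   keepdim_offset - positional offset of `keepdim` in the original signature
#   single_dim     - the method takes a single dim rather than a list of dims
#   reduce         - True (default) if the operation removes the given dims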
  83. _def("mean")
  84. _def("sum")
  85. _def("all")
  86. _def("amax")
  87. _def("amin")
  88. _def("aminmax")
  89. _def("any")
  90. _def("count_nonzero")
  91. _def("logsumexp")
  92. _def("nanmean")
  93. _def("nansum")
  94. _def("prod")
  95. _def("std", keepdim_offset=2)
  96. _def("var", keepdim_offset=2)
  97. _def("max", single_dim=True)
  98. _def("min", single_dim=True)
  99. _def("argmax", single_dim=True)
  100. _def("argmin", single_dim=True)
  101. _def("kthvalue", single_dim=True)
  102. _def("median", single_dim=True)
  103. _def("nanmedian", single_dim=True)
  104. _def("mode", single_dim=True)
  105. _def("sort", reduce=False)
  106. _def("argsort", reduce=False)
  107. _def("unbind", single_dim=True)
  108. _def("chunk", dim_offset=1, reduce=False)
  109. _def("cummax", single_dim=True, reduce=False)
  110. _def("cummin", single_dim=True, reduce=False)
  111. _def("cumprod", single_dim=True, reduce=False)
  112. _def("cumprod_", single_dim=True, reduce=False)
  113. _def("cumsum", single_dim=True, reduce=False)
  114. _def("cumsum_", single_dim=True, reduce=False)
  115. _def("logcumsumexp", single_dim=True, reduce=False)
  116. _def("renorm", dim_offset=1, single_dim=True, reduce=False)
  117. _def("softmax", single_dim=True, reduce=False)
  118. softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
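
# Illustrative sketch (hypothetical names): the wrapped functional form, like
# the patched method, accepts a Dim for `dim`, e.g.
#   i, j = dims(2)
#   logits = torch.randn(3, 4)[i, j]
#   probs = softmax(logits, dim=j)    # or logits.softmax(j)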

# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce