delayed_mul_tensor.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple


class DelayedMulTensor(_Tensor):
    """Lazily represents ``lhs * rhs`` so that a following reduction such as
    ``.sum(dim)`` can be fused into a single einsum instead of materializing
    the full elementwise product first."""

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        # Union of the operands' levels, preserving the left operand's order.
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        # Fallback path: actually materialize the product as a batched tensor.
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                print("bt multiply fallback")
                self._batchtensor_data = (
                    self._lhs._batchtensor * self._rhs._batchtensor
                )
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

    def sum(self, dim):
        # Fuse the delayed multiply with the reduction over `dim` by building
        # an einsum formula, assigning one letter per level.
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_dims = tuple(d for d in self.dims if d not in dims)  # noqa: F841
        new_levels = [l for l in self._levels if l not in dims]
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
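

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal,
# hedged example of the pattern this class accelerates, assuming the
# first-class-dims API from `functorch.dim` (`dims`, indexing with dim
# objects, `.order`) and that `_Tensor.__mul__` routes eligible products
# through DelayedMulTensor. A pointwise multiply followed by `.sum(dim)`
# should then lower to one einsum rather than materializing the product.
if __name__ == "__main__":
    from functorch.dim import dims

    i, j, k = dims(3)
    A = torch.randn(8, 16)
    B = torch.randn(16, 32)
    a = A[i, k]  # bind positional dims to first-class dims (ndim == 0)
    b = B[k, j]
    # The multiply is deferred; summing over k fuses it into the
    # einsum "ik,kj->ij", i.e. a matrix multiply.
    c = (a * b).sum(k)
    print(c.order(i, j).shape)  # torch.Size([8, 32])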