TriangularOpsUtils.h 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #include <ATen/core/Tensor.h>
  2. #include <ATen/native/LinearAlgebraUtils.h>
  3. namespace at::native {
  4. /*
  5. * Given batches of matrices with arbitrary batch dim,
  6. * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
  7. */
  8. static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
  9. int64_t result = 1;
  10. for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
  11. if (batched_matrices.stride(i) != 0) {
  12. result *= batched_matrices.size(i);
  13. }
  14. }
  15. return result;
  16. }
  17. /* Checks a necessary property for the triu and tril implementations, hence the name.
  18. * Here batch contiguity is checked for tensors with greater than 4 dimensions.
  19. * Contiguous tensors and tensors with less than 3 dimensions pass this check
  20. */
  21. static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
  22. // Complete contiguity is the most desired property, which is why
  23. // we return true if the tensor is contiguous
  24. if (tensor.is_contiguous()) {
  25. auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
  26. if (tensor.strides() == default_strides_for_size) {
  27. return std::make_tuple(true, tensor);
  28. } else {
  29. return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
  30. }
  31. }
  32. int64_t dims = tensor.dim();
  33. // Tensors with dimension less than 4 are handled by default
  34. if (allow_zero_stride && dims <= 3) {
  35. return std::make_tuple(true, tensor);
  36. }
  37. int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
  38. for (int64_t i = dims - 3; i >= 0; i--) {
  39. // Skip trivial dimension;
  40. if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
  41. continue;
  42. }
  43. if (expected_stride != tensor.stride(i)) {
  44. return std::make_tuple(false, tensor.contiguous());
  45. }
  46. expected_stride *= tensor.size(i);
  47. }
  48. return std::make_tuple(true, tensor);
  49. }
  50. } // namespace at::native