TensorNames.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #pragma once
  2. #include <ATen/WrapDimUtils.h>
  3. namespace at::namedinference {
  4. // TensorName and TensorNames are wrappers around Dimname and DimnameList
  5. // that contain helper functions to make writing name inference rules easier.
  6. //
  7. // A TensorName represents a Dimname associated with some DimnameList (from a
  8. // Tensor). This encapsulates all the information that is needed to check if
  9. // names *match* and to *unify* names.
  10. //
  11. // Definition: Two names in two tensors *match* if they are equal, or if at
  12. // least one of them is a wildcard that can be *refined* to the other name.
  13. //
  14. // Definition: unify(name, other) fails if the names do not match. Otherwise,
  15. // it returns the most refined of name and other.
  16. //
  17. // Here is an example of checking if two names match.
  18. // tensor: Tensor[A, None]
  19. // other: Tensor[A]
  20. //
  21. // Let's say we wish to check if tensor.names[-1] matches other.names[-1].
  22. // None (in tensor) cannot match A (in other) because if the None were refined
  23. // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
  24. // tensor.names [A, None] for the existence of A.
  25. struct TORCH_API TensorName {
  26. explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
  27. : origin_(origin),
  28. name_(origin[maybe_wrap_dim(
  29. origin_idx,
  30. static_cast<int64_t>(origin.size()))]),
  31. origin_idx_(origin_idx) {}
  32. // op_name is only used for error reporting.
  33. const TensorName& unify(const TensorName& other, const char* op_name) const;
  34. Dimname toDimname() const;
  35. private:
  36. ArrayRef<Dimname> origin_;
  37. Dimname name_;
  38. int origin_idx_; // A named tensor can have at most 64 dims.
  39. TORCH_API friend std::ostream& operator<<(
  40. std::ostream& out,
  41. const TensorName& tensorname);
  42. };
  43. using TensorNameVec = SmallVector<TensorName, 10>;
  44. struct TORCH_API TensorNames {
  45. explicit TensorNames(ArrayRef<Dimname> names);
  46. // Create TensorNames from names[start:end]. Each individual TensorName stores
  47. // `names`, NOT names[start:end], because the original tensor's names are
  48. // `names`.
  49. explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
  50. // op_name is only used for error reporting.
  51. TensorNames& unifyFromRightInplace(
  52. const TensorNames& other,
  53. const char* op_name = "unify");
  54. void checkUnique(const char* op_name) const;
  55. void append(TensorName name);
  56. std::vector<Dimname> toDimnameVec() const;
  57. private:
  58. explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)){};
  59. TensorNameVec names_;
  60. };
  61. } // namespace at::namedinference