// TensorMeta.h
#pragma once

#include <ATen/DimVector.h>
#include <ATen/core/Dimname.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/strides.h>

namespace at {

class Tensor;

namespace impl {
  9. // Use this to define the prototype for a meta function. There are two
  10. // versions; one that takes one argument (just the operator name), or FUNC2
  11. // variant that takes two arguments (operator name and overload name).
  12. //
  13. // Example usage:
  14. //
  15. // TORCH_META_FUNC2(add, Tensor) (
  16. // const Tensor& self, const Tensor& other
  17. // ) {
  18. // ... compute sizes and options ...
  19. // set_output(sizes, options);
  20. // }
  21. //
  22. #define TORCH_META_FUNC(name) void structured_##name::meta
  23. #define TORCH_META_FUNC2(name, overload) \
  24. void structured_##name##_##overload::meta
  25. // These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
  26. // as a return value. They should be used when the kernel in question has
  27. // precomputed values declared in native_functions.yaml and the corresponding
  28. // implementation should return an instance of the aforementioned struct.
  29. #define TORCH_PRECOMPUTE_META_FUNC(name) \
  30. structured_##name::meta_return_ty structured_##name::meta
  31. #define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  32. structured_##name##_##overload::meta_return_ty \
  33. structured_##name##_##overload::meta
  34. // Use this to create a precompute struct in a meta function.
  35. #define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
  36. #define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  37. structured_##name##_##overload::precompute_out<>
  38. // Use this to define the prototype for an implementation. This takes only
  39. // one argument, which is the name of the dispatch key entry you're
  40. // implementing.
  41. //
  42. // Example usage:
  43. //
  44. // TORCH_IMPL_FUNC(add_cpu) (
  45. // Tensor& result, const Tensor& self, const Tensor& other
  46. // ) {
  47. // ... do the actual implementation ...
  48. // }
  49. //
  50. #define TORCH_IMPL_FUNC(name) void structured_##name::impl
  51. // Base class for all structured kernel classes. The set_output virtual
  52. // method is varied depending whether or not the operator is
  53. // functional/out/inplace, and could also be specialized for CPU/CUDA/etc
  54. // (although presently it isn't).
  55. //
  56. // A notable subclass of this interface is TensorIteratorBase.
  57. struct TORCH_API MetaBase {
  58. MetaBase() = default;
  59. MetaBase(const MetaBase&) = default;
  60. MetaBase& operator=(const MetaBase&) = default;
  61. MetaBase(MetaBase&&) noexcept = default;
  62. MetaBase& operator=(MetaBase&&) noexcept = default;
  63. virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
  64. // Note: [set_output_*]
  65. // See: https://github.com/pytorch/pytorch/issues/69813
  66. // Whenever defining the output properties in the META function of a
  67. // structured kernel (what was usually done with `set_output`), use one of
  68. // these 3 variants, instead. In order to decide which variant to use, check
  69. // the following decision tree:
  70. //
  71. // - Can the kernel you are going to implement support output tensors
  72. // with arbitrary strides?
  73. // |
  74. // -- YES: `set_output_raw_strided`
  75. // |
  76. // -- NO: Should the output tensor strides be contiguous?
  77. // |
  78. // -- YES: `set_output_contiguous`
  79. // |
  80. // -- NO: `set_output_strided`
  81. //
  82. // Use this function whenever the kernel requires specific strides for the
  83. // output. If `strides` does not match the given output strides, proxy outputs
  84. // will be created and passed to the IMPL function.
  85. virtual void set_output_strided(
  86. int64_t output_idx,
  87. IntArrayRef sizes,
  88. IntArrayRef strides,
  89. TensorOptions options,
  90. DimnameList names = {}) {
  91. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  92. }
  93. // Use this function whenever the kernel knows how to handle arbitrary strided
  94. // outputs. This function has the same behavior as the old `set_output`: it
  95. // will only re-stride if the given output was resized.
  96. virtual void set_output_raw_strided(
  97. int64_t output_idx,
  98. IntArrayRef sizes,
  99. IntArrayRef strides_hint,
  100. TensorOptions options,
  101. DimnameList names = {}) {
  102. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  103. }
  104. // Use this function if the kernel requires contiguous strides.
  105. // Alias for `set_output_strided`, but with contiguous strides.
  106. void set_output_contiguous(
  107. int64_t output_idx,
  108. IntArrayRef sizes,
  109. TensorOptions options,
  110. DimnameList names = {}) {
  111. auto strides = c10::contiguous_strides(sizes);
  112. set_output_strided(output_idx, sizes, strides, options, names);
  113. }
  114. // Returns a reference to an undefined tensor if there is no presupplied
  115. // output
  116. const Tensor& maybe_get_output() {
  117. return maybe_get_output(0);
  118. }
  119. virtual ~MetaBase() = default;
  120. };
} // namespace impl
} // namespace at