// TensorSubclassLikeUtils.h
#pragma once

#include <algorithm>

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif
namespace at {

// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
//   TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
//   For example, many Tensor subclasses do not have storage and meta Tensors
//   do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals supports Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() or
//    .data_ptr() calls are not good.
// 2. Certain in-place operations can eliminate the typing of the Tensor
//    subclass. For example:
//    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
//    If input is a Tensor subclass, then the above ends up either erroring out
//    or returning a regular non-Tensor-subclass Tensor!
  29. constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
  30. {DispatchKey::FuncTorchGradWrapper,
  31. DispatchKey::FuncTorchBatched,
  32. DispatchKey::Functionalize});
  33. constexpr auto kTensorSubclassLike =
  34. kFunctorchWrappedTensors |
  35. DispatchKeySet(
  36. {// WARNING: DO NOT put combined backend component + functionality keys
  37. // here, you will incorrectly always match on the functionality key
  38. // no matter the backend component
  39. DispatchKey::Batched,
  40. DispatchKey::Sparse,
  41. DispatchKey::SparseCsr,
  42. DispatchKey::Python}) |
  43. DispatchKeySet(BackendComponent::MetaBit);
  44. inline bool isTensorSubclassLike(const Tensor& tensor) {
  45. if (c10::impl::dispatch_mode_enabled())
  46. return true;
  47. auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  48. return !(key_set & kTensorSubclassLike).empty();
  49. }
  50. inline bool areAnyTensorSubclassLike(TensorList tensors) {
  51. if (c10::impl::dispatch_mode_enabled())
  52. return true;
  53. return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
  54. }
  55. inline bool areAnyOptionalTensorSubclassLike(
  56. const c10::List<std::optional<Tensor>>& tensors) {
  57. if (c10::impl::dispatch_mode_enabled())
  58. return true;
  59. return std::any_of(
  60. tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
  61. return (
  62. opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
  63. });
  64. }
  65. // Helper function to deal testing truthfulness of a scalar tensor
  66. // in a Composite Compliant manner.
  67. // NOTE: This function expects a scalar tensor of boolean dtype.
  68. // Eg.
  69. // Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
  70. // Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
  71. inline bool is_scalar_tensor_true(const Tensor& t) {
  72. TORCH_INTERNAL_ASSERT(t.dim() == 0)
  73. TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
  74. return at::equal(t, t.new_ones({}, t.options()));
  75. }
} // namespace at