BucketizationUtils.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/ScalarOps.h>
#include <climits> // for INT_MAX, used in searchsorted_pre_check below

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type.h>
#endif

namespace at::native {

// Original inputs are given by the raw_* arguments. If an input is not
// contiguous, a contiguous copy is written to the corresponding trimmed_*
// output. Additionally, if the dtypes of the boundary and input tensors do
// not match, both are cast to a common super type so comparisons are done
// between the same types. For any trimmed_* tensor, if its outgoing value
// matches its incoming value (typically undefined), the corresponding raw_*
// tensor should be used instead, since it was already contiguous and of the
// right type.
inline void searchsorted_maybe_trim_input_tensors(
    Tensor& trimmed_input,
    Tensor& trimmed_boundaries,
    Tensor& trimmed_sorter,
    const Tensor& raw_input,
    const Tensor& raw_boundaries,
    const Tensor& raw_sorter) {
  bool in_is_contiguous = raw_input.is_contiguous();
  bool bd_is_contiguous = raw_boundaries.is_contiguous();
  bool sort_is_contiguous = raw_sorter.is_contiguous();

  if (!in_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
      "tensor if possible. This message will only appear once per program.");
    trimmed_input = raw_input.contiguous();
  }
  if (!bd_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
      "tensor if possible. This message will only appear once per program.");
    trimmed_boundaries = raw_boundaries.contiguous();
  }
  if (!sort_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
      "tensor if possible. This message will only appear once per program.");
    trimmed_sorter = raw_sorter.contiguous();
  }
  if (raw_input.dtype() != raw_boundaries.dtype()) {
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(raw_boundaries, state);
    state = at::native::update_result_type_state(raw_input, state);
    ScalarType common_stype = at::native::result_type(state);

    TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
    if (common_stype != raw_input.scalar_type()) {
      trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
    }
    if (common_stype != raw_boundaries.scalar_type()) {
      trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
    }
  }
}
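
// A minimal usage sketch (illustrative only; the trimmed_* locals and raw_*
// inputs below are hypothetical caller-side names, not part of this header).
// Per the contract above, callers fall back to the raw tensors whenever the
// corresponding trimmed output was left undefined:
//
//   Tensor trimmed_in, trimmed_bd, trimmed_sort;
//   searchsorted_maybe_trim_input_tensors(
//       trimmed_in, trimmed_bd, trimmed_sort, raw_in, raw_bd, raw_sort);
//   const Tensor& in = trimmed_in.defined() ? trimmed_in : raw_in;
//   const Tensor& bd = trimmed_bd.defined() ? trimmed_bd : raw_bd;
//   const Tensor& sort = trimmed_sort.defined() ? trimmed_sort : raw_sort;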

/* unused but needed for internal jagged tensor class */
inline void searchsorted_maybe_trim_input_tensors(
    Tensor& trimmed_input,
    Tensor& trimmed_boundaries,
    const Tensor& raw_input,
    const Tensor& raw_boundaries) {
  Tensor trimmed_sorter;
  Tensor raw_sorter;
  return searchsorted_maybe_trim_input_tensors(
      trimmed_input,
      trimmed_boundaries,
      trimmed_sorter,
      raw_input,
      raw_boundaries,
      raw_sorter);
}

inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
  if (boundaries.dim() != input.dim()) {
    return false;
  }
  const auto& dims_bd = boundaries.sizes();
  const auto& dims_in = input.sizes();
  for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
    if (dims_bd[dim] != dims_in[dim]) {
      return false;
    }
  }
  return true;
}
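
// Example (shapes are illustrative): boundaries of size [2, 3, 5] and input
// of size [2, 3, 7] match, since only the last dimension may differ;
// boundaries of size [2, 4, 5] against the same input do not.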

inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
  auto tensor = c10::scalar_to_tensor(scalar, device);
  // This is to adopt the scalar promotion rules defined in
  // native/TypeProperties.h, so we have the same type promotion rules as
  // binary operations.
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
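
// For example (assuming standard type-promotion behavior), the wrapped-number
// flag means an integral Scalar searched against a float boundaries tensor
// promotes to float, just as `float_tensor + 3` does, rather than forcing the
// comparison up to Long.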

inline void searchsorted_pre_check(
    const Tensor& boundaries,
    const Tensor& input,
    const Tensor& output,
    const bool out_int32,
    const bool right,
    const std::optional<c10::string_view> side_opt,
    const Tensor& sorter) {
  if (side_opt) {
    const c10::string_view side = *side_opt;
    TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
      "got ", side);

    // assume the user has not explicitly set (right=False, side="right")
    TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
      "of ", side, " while right was True");
  }

  TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
    "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
    "tensor device type ", input.device());

  if (sorter.defined()) {
    TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
      "have same device type, but got sorter tensor device type ", sorter.device(), " and boundary tensor ",
      "device type ", boundaries.device());

    TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
      "size, but got boundary tensor ", boundaries.sizes(), " and got sorter tensor ", sorter.sizes());

    TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
      "dtype but got dtype ", sorter.scalar_type());

    if (sorter.numel() > 0) {
      auto minmax = sorter.aminmax();
      int64_t vmin = std::get<0>(minmax).item().toLong();
      int64_t vmax = std::get<1>(minmax).item().toLong();
      TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
    }
  }

  TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
    "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
    "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
    input.numel(), ")");

  TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
    "got 0 dimension");

  TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
    "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
    "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
    input.sizes());

  ScalarType output_dtype = output.scalar_type();
  TORCH_CHECK(
      (output_dtype == ScalarType::Long && !out_int32) ||
          (output_dtype == ScalarType::Int && out_int32),
      "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
      "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
      " and out_int32 flag is ", (out_int32 ? "True" : "False"));

  if (out_int32) {
    TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
      "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
      boundaries.sizes().back());
  }
}
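
// A shape sketch of arguments these checks accept (values are illustrative):
// boundaries [2, 3, 5], input [2, 3, 7], a Long sorter of size [2, 3, 5], and
// a Long output with out_int32 == false, all on one device, pass every check;
// mismatching the leading [2, 3] dims or passing a Float sorter would throw.
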
} // namespace at::native