| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/native/DispatchStub.h>
- namespace at::native {
- using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
- // reflection padding
- DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
- DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
- DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
- DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
- DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
- DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);
- // replication padding
- DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
- DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
- DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
- DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
- DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
- DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
- namespace padding {
- template <int dim>
- inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
- TORCH_CHECK(padding.size() == 2 * dim,
- "padding size is expected to be ", 2 * dim,
- ", but got: ", padding.size());
- int input_dim = input.dim();
- bool is_batch_mode = input_dim == (dim + 2);
- bool valid_batch_mode = is_batch_mode;
- bool valid_non_batch_mode = !is_batch_mode;
- if (is_batch_mode) {
- // allow batch size of 0-dim.
- for (const auto d : c10::irange(1, input_dim)) {
- valid_batch_mode = valid_batch_mode && input.size(d) != 0;
- }
- } else {
- for (const auto d : c10::irange(0, input_dim)) {
- valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
- }
- }
- // allow empty batch size but not other dimensions.
- TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
- "Expected ", dim + 1, "D or ", dim + 2,
- "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
- input.sizes());
- }
- } // namespace padding
- } // at::native
|