Normalization.h 554 B

12345678910111213141516171819
  1. #pragma once
  2. #include <ATen/TensorIterator.h>
  3. #include <ATen/native/DispatchStub.h>
  4. namespace at::native {
  5. using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
  6. DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
  7. enum class BatchNormBackend {
  8. Native,
  9. Cudnn,
  10. Miopen,
  11. };
  12. TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
  13. } // namespace at::native