// AmpKernels.h
#pragma once

#include <cstdint>

#include <ATen/native/DispatchStub.h>
#include <ATen/core/ATen_fwd.h>
  4. namespace at {
  5. class Tensor;
  6. namespace native {
  7. using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
  8. TensorList,
  9. Tensor&,
  10. const Tensor&);
  11. using _amp_update_scale_cpu__fn = Tensor& (*)(
  12. Tensor&,
  13. Tensor&,
  14. const Tensor&,
  15. double,
  16. double,
  17. int64_t);
  18. DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
  19. DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
  20. } // namespace native
  21. } // namespace at