MPSGeneratorImpl.h 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  #include <ATen/core/Generator.h>
  #include <ATen/core/PhiloxRNGEngine.h>
  #include <c10/core/GeneratorImpl.h>
  #include <c10/util/Optional.h>

  #include <array>
  #include <cstdint>
  #include <memory>
  7. namespace at {
  8. namespace mps::detail {
  9. static const uint32_t PHILOX_STATE_N = 7;
  10. struct rng_data_pod {
  11. std::array<uint32_t, PHILOX_STATE_N> state{1};
  12. uint64_t seed = default_rng_seed_val;
  13. };
  14. TORCH_API const Generator& getDefaultMPSGenerator();
  15. TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
  16. } // namespace mps::detail
  17. struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
  18. // Constructors
  19. MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  20. ~MPSGeneratorImpl() override = default;
  21. // MPSGeneratorImpl methods
  22. std::shared_ptr<MPSGeneratorImpl> clone() const;
  23. void set_current_seed(uint64_t seed) override;
  24. void set_offset(uint64_t offset) override;
  25. uint64_t get_offset() const override;
  26. uint64_t current_seed() const override;
  27. uint64_t seed() override;
  28. void set_state(const c10::TensorImpl& new_state) override;
  29. c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  30. void update_philox_counters();
  31. void set_engine(at::Philox4_32 engine) { engine_ = engine; };
  32. at::Philox4_32 engine() { return engine_; };
  33. uint32_t* state_data() { return data_.state.data(); }
  34. static DeviceType device_type() { return DeviceType::MPS; };
  35. private:
  36. mps::detail::rng_data_pod data_;
  37. at::Philox4_32 engine_;
  38. MPSGeneratorImpl* clone_impl() const override;
  39. };
  40. } // namespace at