MPSAllocatorInterface.h 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. // Copyright © 2023 Apple Inc.
  2. #pragma once
  3. #include <c10/core/Allocator.h>
  4. #include <c10/util/Registry.h>
  5. #include <ATen/core/ATen_fwd.h>
  // MB(x): converts a count of mebibytes (2^20 = 1048576 bytes) into bytes.
  // The argument is parenthesized so expressions such as MB(a + b) expand to
  // ((a + b) * 1048576UL) rather than (a + b * 1048576UL). The UL suffix makes
  // the product unsigned long for integer arguments.
  // NOTE(review): `x` is still evaluated once per expansion; avoid arguments with
  // side effects.
  #define MB(x) ((x) * 1048576UL)
  7. namespace at::mps {
  8. // this is a public interface to access MPSAllocator.
  9. // Do not declare methods that would depend on MPS or Metal frameworks.
  10. class IMPSAllocator : public c10::Allocator {
  11. public:
  12. // see the comments in MPSAllocator.h for the description of these methods.
  13. virtual void emptyCache() const = 0;
  14. virtual void freeInactiveBuffers() const = 0;
  15. virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
  16. virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
  17. virtual id_t getBufferId(const void* ptr) const = 0;
  18. virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
  19. virtual bool isSharedBuffer(const void* ptr) const = 0;
  20. virtual bool isSharedStorageSupported() const = 0;
  21. virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
  22. virtual std::string formatSize(size_t size) const = 0;
  23. virtual void setLowWatermarkRatio(double ratio) const = 0;
  24. virtual void setHighWatermarkRatio(double ratio) const = 0;
  25. virtual ssize_t getLowWatermarkValue() const = 0;
  26. virtual size_t getLowWatermarkLimit() const = 0;
  27. virtual size_t getHighWatermarkLimit() const = 0;
  28. virtual size_t getTotalAllocatedMemory() const = 0;
  29. virtual size_t getCurrentAllocatedMemory() const = 0;
  30. virtual size_t getDriverAllocatedMemory() const = 0;
  31. virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
  32. virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
  33. virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
  34. };
  35. class IMpsAllocatorCallback {
  36. public:
  37. enum class EventType {
  38. ALLOCATED, // buffer got allocated to be used immediately
  39. RECYCLED, // buffer pulled from free list to be reused
  40. FREED, // buffer put to free list for future recycling
  41. RELEASED, // buffer memory released
  42. ALLOCATION_FAILED // buffer allocation failed
  43. };
  44. virtual ~IMpsAllocatorCallback() = default;
  45. virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
  46. };
  47. // MPS allocator will execute every registered callback when a block of memory is freed.
  48. C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
  49. #define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
  50. C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
  51. IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
  52. } // namespace at::mps