MPSStream.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  3. #include <cstdint>
  4. #include <utility>
  5. #include <c10/core/DeviceGuard.h>
  6. #include <c10/util/Exception.h>
  7. #include <c10/core/Stream.h>
  8. #include <ATen/mps/MPSDevice.h>
  9. #ifdef __OBJC__
  10. #include <Foundation/Foundation.h>
  11. #include <Metal/Metal.h>
  12. #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
  13. #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
  14. typedef id<MTLCommandQueue> MTLCommandQueue_t;
  15. typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
  16. typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
  17. typedef id<MTLSharedEvent> MTLSharedEvent_t;
  18. typedef id<MTLDevice> MTLDevice_t;
  19. #else
  20. typedef void* MTLCommandQueue_t;
  21. typedef void* MTLCommandQueue;
  22. typedef void* MTLCommandBuffer_t;
  23. typedef void* MTLCommandBuffer;
  24. typedef void* MTLComputeCommandEncoder_t;
  25. typedef void* MTLSharedEvent_t;
  26. typedef void* dispatch_queue_t;
  27. typedef void* MTLDevice_t;
  28. #define nil NULL;
  29. #endif
  30. namespace at::mps {
  31. //-----------------------------------------------------------------
  32. // MPSStream
  33. //-----------------------------------------------------------------
  34. enum class SyncType {
  35. NONE, // no commit to command buffer
  36. COMMIT, // commit and flush the command buffer
  37. COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  38. COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
  39. COMMIT_ADAPTIVE, // commit adaptively based on available memory
  40. };
  41. class TORCH_API MPSStream
  42. {
  43. public:
  44. enum Unchecked { UNCHECKED };
  45. /// Construct a MPSStream from a Stream. This construction is checked,
  46. /// and will raise an error if the Stream is not, in fact, a MPS stream.
  47. explicit MPSStream(Stream stream);
  48. ~MPSStream();
  49. MTLCommandQueue_t commandQueue() const { return _commandQueue; };
  50. dispatch_queue_t queue() const { return _serialQueue; }
  51. MPSCommandBuffer* commandBuffer();
  52. MTLComputeCommandEncoder_t commandEncoder();
  53. void endKernelCoalescing();
  54. void synchronize(SyncType syncType);
  55. void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  56. void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
  57. size_t length, size_t srcOffset, size_t dstOffset,
  58. uint64_t profileId, SyncType syncType = SyncType::NONE);
  59. void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
  60. size_t length, size_t srcOffset, size_t dstOffset,
  61. bool non_blocking, uint64_t profileId);
  62. void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
  63. void addCompletedHandler(MTLCommandBufferHandler block);
  64. /// Get the MPS device index that this stream is associated with.
  65. c10::DeviceIndex device_index() const { return _stream.device_index(); }
  66. MTLCommandQueue_t stream() const { return _commandQueue; };
  67. MTLDevice_t device() const { return [_commandQueue device];}
  68. /// Explicit conversion to Stream.
  69. Stream unwrap() const { return _stream; }
  70. private:
  71. Stream _stream;
  72. MTLCommandQueue_t _commandQueue = nil;
  73. MPSCommandBuffer* _commandBuffer = nil;
  74. MPSCommandBuffer* _prevCommandBuffer = nil;
  75. MTLComputeCommandEncoder_t _commandEncoder = nil;
  76. MPSGraphExecutionDescriptor *_executionDescriptor = nil;
  77. MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
  78. dispatch_queue_t _serialQueue = nullptr;
  79. // CommitAndContinue is enabled by default
  80. bool _enableCommitAndContinue = true;
  81. // use synchronize() to access any of these commit functions outside MPSStream
  82. void commit();
  83. void commitAndWait();
  84. void commitAndContinue();
  85. void flush();
  86. };
  87. /**
  88. * Get the current MPS stream
  89. */
  90. TORCH_API MPSStream* getCurrentMPSStream();
  91. /**
  92. * Get the default MPS stream
  93. */
  94. TORCH_API MPSStream* getDefaultMPSStream();
  95. //-----------------------------------------------------------------
  96. // MPSStreamImpl
  97. //-----------------------------------------------------------------
  98. class TORCH_API MPSStreamImpl
  99. {
  100. public:
  101. /**
  102. * Gets single instance of the MPSStream.
  103. */
  104. static MPSStream* getInstance();
  105. private:
  106. static MPSStream* _stream;
  107. MPSStreamImpl();
  108. };
  109. } // namespace at::mps