// Copyright © 2022 Apple Inc.

#pragma once
#include <initializer_list>

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/util/Optional.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
#include <exception>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

// Fwd declarations
namespace at {
struct TensorIteratorBase;
}
using namespace at::mps;

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
  id<MTLBuffer> getMTLBuffer() const {
    return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
  }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value{};
};
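
// Illustrative sketch (assumed variables: `scalar` is an at::Scalar, `stream` an
// MPSStream*): a host scalar is typically packed via getMPSScalar() and uploaded
// with getMPSGraphTensorFromScalar(), both declared later in this header:
//   MPSScalar alpha = getMPSScalar(scalar, ScalarType::Float);
//   MPSGraphTensorData* alphaData = getMPSGraphTensorFromScalar(stream, alpha);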

void runMPSGraph(MPSStream* mpsStream,
                 MPSGraph* mpsGraph,
                 NSDictionary* feeds,
                 NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const Tensor& t) {
  return getMPSDataType(t.scalar_type());
}
MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const Tensor& t) {
  return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
  return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const Tensor& t) {
  return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getArrayRefString(const IntArrayRef s);

// Use has_storage() on the returned tensor to determine whether src is actually a view.
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);

// The MPSShape can vary based on memory format.
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
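
// Note: getMTLBufferStorage() returns the buffer backing the tensor's storage;
// callers must apply tensor.storage_offset() * tensor.element_size() as the
// byte offset themselves (as mtl_setBuffer() below does).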

class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape* mpsShape = nullptr,
              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  bool isIntermediate() {
    return _value == nullptr;
  }

 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};
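
// Illustrative sketch of the typical Placeholder pattern (a minimal example;
// `cachedGraph`, `self` and `output` are assumed to exist at the call site):
//   Placeholder inputPlaceholder(cachedGraph->inputTensor_, self);
//   Placeholder outputPlaceholder(cachedGraph->outputTensor_, output);
//   auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
//   runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);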

void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
MPSGraph* make_mps_graph();
void printTensorNDArray(const Tensor& t);
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);
std::string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = uint64_t;

// Derive from this class to cache a graph and its inputs/outputs;
// it can be used to store any NSObject.
struct MPSCachedGraph {
  MPSCachedGraph(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
    [_object release];
    _object = nullptr;
  }

  template <typename T>
  inline T* as() {
    return static_cast<T*>(this);
  }

  MPSGraph* graph() const { return (MPSGraph*)_object; }
  NSObject* object() const { return _object; }

 private:
  NSObject* _object = nullptr;
};
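
// A base MPSCachedGraph* can be downcast to the concrete type with as<T>(), e.g.
//   auto* unaryCachedGraph = cachedGraph->as<MPSUnaryCachedGraph>();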

struct MPSUnaryCachedGraph : public MPSCachedGraph {
  MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
  MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor* gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph {
  MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
  MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* gradInputTensor_ = nil;
};

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache {
  typedef MPSCachedGraph* (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    std::string key_;
  };

 public:
  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  ~MPSGraphCache() {
    dispatch_release(serialQueue_);
    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    __block MPSCachedGraph* cachedGraph = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // check whether a cached entry already exists; if so, reuse it
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedGraph(key, createCacheBlock));
  }

  MPSCachedGraph* LookUp(const std::string& key) const {
    __block MPSCachedGraph* cachedGraph = nullptr;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // this is defined in OperationUtils.mm to not include
  // MPSProfiler.h in header OperationUtils.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common helper: look up a cached graph by key, creating and caching it if missing.
template <typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize the graph and let the caller populate its tensors
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}
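
// Illustrative call-site sketch (a minimal example; the op name, cache key, and
// tensor `self` are assumptions; negativeWithTensor:name: is a standard MPSGraph op):
//   auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(
//       "neg_out_mps" + getTensorsStringKey({self}),
//       [&](MPSGraph* mpsGraph, MPSUnaryCachedGraph* newCachedGraph) {
//         newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
//         newCachedGraph->outputTensor_ = [mpsGraph negativeWithTensor:newCachedGraph->inputTensor_
//                                                                 name:nil];
//       });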

// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                   \
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                               \
    TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name,                                  \
                    ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
  }
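
// Illustrative usage (a sketch; the availability check is an assumption about the
// call site, and MACOS_VER_13_3_PLUS is assumed to exist alongside the
// MACOS_VER_14_0_PLUS value used below):
//   bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
//   MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "cumsum_out_mps");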

/**
 * Returns the distance from the lowest to the highest element offset in the given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t);

/**
 * Checks whether the tensor is mapped to a contiguous area in its storage.
 */
inline bool is_dense_in_storage(const at::Tensor& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}
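
// For example, the transpose of a contiguous matrix is still dense in storage
// (every offset between the lowest and highest belongs to the view), while a
// step-2 slice of it is not.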

class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0) : shaderSource(src), nparams(nparams_) {}
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname);
  }
  id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
                                                      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname);
  }

 private:
  id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
  id<MTLLibrary> getLibrary();
  id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
  id<MTLLibrary> compileLibrary(const std::string& src);

  std::string shaderSource;
  unsigned nparams;
  id<MTLLibrary> library = nil;
  std::unordered_map<std::string, id<MTLLibrary>> libMap;
  std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
};
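
// Illustrative sketch (the shader source and kernel name here are made-up examples):
//   static MetalShaderLibrary lib(R"MTL(
//     kernel void fill_ones(device float* out [[buffer(0)]],
//                           uint idx [[thread_position_in_grid]]) {
//       out[idx] = 1.0;
//     }
//   )MTL");
//   id<MTLComputePipelineState> fillPSO = lib.getPipelineStateForFunc("fill_ones");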

static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
  [encoder setBuffer:getMTLBufferStorage(t)
              offset:t.storage_offset() * t.element_size()
             atIndex:idx];
}

static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                     id<MTLComputePipelineState> cplState,
                                     uint32_t length) {
  const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
  auto size = MTLSizeMake(length, 1, 1);
  auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
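
// Illustrative sketch of encoding a 1D kernel with the helpers above
// (`encoder`, `fillPSO` and `output` are assumed to exist at the call site):
//   [encoder setComputePipelineState:fillPSO];
//   mtl_setBuffer(encoder, output, 0);
//   mtl_dispatch1DJob(encoder, fillPSO, static_cast<uint32_t>(output.numel()));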

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{ p1.getMPSGraphTensor() : p1.getMPSGraphTensorData() };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(),
  };
}

inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}

inline bool supportsComplex() {
  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}

// MPS does not support double types yet; bfloat16 is supported starting with macOS 14.
inline bool supportedFloatingType(ScalarType dtype) {
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
}

inline bool supportedFloatingType(const Tensor& t) {
  return supportedFloatingType(t.scalar_type());
}

inline bool supportedFloatingOrComplexType(ScalarType dtype) {
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
    return supportsComplex();
  }
  return supportedFloatingType(dtype);
}

inline bool supportedFloatingOrComplexType(const Tensor& t) {
  return supportedFloatingOrComplexType(t.scalar_type());
}

inline bool needsGather(const Tensor& t) {
  return !t.is_contiguous() || t.storage_offset();
}

} // namespace at::native::mps