| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- // Copyright © 2022 Apple Inc.
- #pragma once
- #include <initializer_list>
- #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
- #include <ATen/Tensor.h>
- #include <ATen/Utils.h>
- #include <ATen/mps/MPSStream.h>
- #include <ATen/native/mps/TensorFactory.h>
- #include <c10/util/Optional.h>
- #include <c10/core/ScalarType.h>
- #include <torch/library.h>
- #include <exception>
- #include <unordered_map>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #include <ATen/NativeFunctions.h>
- #else
- #include <ATen/ops/empty.h>
- #include <ATen/ops/empty_like.h>
- #include <ATen/ops/zeros.h>
- #include <ATen/ops/zeros_like.h>
- #endif
- #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
// Fwd declarations
namespace at {
struct TensorIteratorBase;
}
// NOTE(review): `using namespace` in a header leaks at::mps into every includer;
// kept as-is for source compatibility with existing .mm files.
using namespace at::mps;

namespace at::native::mps {

// Runs `block` synchronously on `queue`, like dispatch_sync, but captures any
// C++ exception thrown inside the block and rethrows it on the calling thread
// (defined in OperationUtils.mm).
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
// A scalar value staged in a Metal buffer so it can be fed to an MPSGraph.
// The owning c10::DataPtr releases the underlying MTLBuffer on destruction.
struct MPSScalar {
  // Reinterprets the raw pointer held by `buffer` as the MTLBuffer it wraps.
  id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }

  size_t size = 0;                             // byte size of the scalar payload
  ScalarType type = ScalarType::Undefined;     // dtype describing which union member is active
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  // Only the member selected by `type` is meaningful; zero-initialized by default.
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value {};
};
// Executes `mpsGraph` on the given stream, feeding the placeholder->data pairs
// in `feeds` and writing into the tensor-data objects listed in `results`.
void runMPSGraph(MPSStream* mpsStream,
                 MPSGraph* mpsGraph,
                 NSDictionary* feeds,
                 NSDictionary* results);
- MPSDataType getMPSDataType(ScalarType scalar_type);
- static inline MPSDataType getMPSDataType(const Tensor& t) {
- return getMPSDataType(t.scalar_type());
- }
- MPSDataType getMPSScalarType(ScalarType scalar_type);
- static inline MPSDataType getMPSScalarType(const Tensor& t) {
- return getMPSScalarType(t.scalar_type());
- }
- MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
- std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
- static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
- return getMPSTypeString(t.scalar_type(), short_name);
- }
- std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
- static inline std::string scalarToMetalTypeString(const Tensor& t) {
- return scalarToMetalTypeString(t.scalar_type());
- }
// Axis index lists (as NSNumber arrays) for reduction-style MPSGraph calls.
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
// Debug/cache-key string helpers.
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
// View-tensor fast paths: slice directly from the base storage when possible.
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
// Cast helpers for dtypes MPS kernels handle natively (int/half/float family);
// `includesInt64` opts into native int64 where the OS supports it.
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
// The MPSShape could vary based on memory format
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
- static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
- return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
- }
// Couples an MPSGraph placeholder tensor with the MPSGraphTensorData that will
// be bound to it at execution time. `_tensor` keeps the (possibly gathered)
// ATen tensor alive for as long as the graph may read its buffer.
class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  // Binds `self` to the graph tensor; when `gatherTensorData` is set, a
  // non-contiguous view may be gathered into a dense temporary (see .mm).
  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  // An "intermediate" placeholder has no data bound to it yet.
  bool isIntermediate() {
    return _value == nullptr;
  }
 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};
- void resize_tensor(Tensor* output);
- Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
- MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
- MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
- MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
- MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
- MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
- MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
- MPSGraph* make_mps_graph();
- void printTensorNDArray(const Tensor& t);
- MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType);
- MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
- MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
- MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
- MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
- MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);
- string get_mem_format_string(c10::MemoryFormat memory_format);
- using MPSCacheKey = uint64_t;
- // derive this class to cache a graph and its inputs/outputs
- // can be used to store any NSObject
- struct MPSCachedGraph
- {
- MPSCachedGraph(NSObject *object) : _object([object retain]) {}
- virtual ~MPSCachedGraph() {
- [_object release];
- _object = nullptr;
- }
- template<typename T>
- inline T* as() {
- return static_cast<T*>(this);
- }
- MPSGraph *graph() const { return (MPSGraph *)_object; }
- NSObject *object() const { return _object; }
- private:
- NSObject *_object = nullptr;
- };
// Cached graph shape for unary ops: one input tensor, one output tensor.
struct MPSUnaryCachedGraph : public MPSCachedGraph
{
  MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};
// Cached graph shape for the backward pass of a unary op.
struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
  MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor *gradInputTensor_ = nil;
};
// Cached graph shape for binary ops: input, other, output.
struct MPSBinaryCachedGraph : public MPSCachedGraph
{
  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};
// Cached graph shape for the backward pass of a binary op.
struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
  MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *gradInputTensor_ = nil;
};
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
// Process-wide singleton; all map access is serialized on a private GCD queue.
struct MPSGraphCache
{
  typedef MPSCachedGraph * (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    // Full key kept alongside the 64-bit hash to detect hash collisions.
    std::string key_;
  };

public:
  // NOTE(review): lazy init is not guarded; looks like first use is assumed to
  // happen before concurrent access — confirm before relying on it from
  // multiple threads.
  static MPSGraphCache* getInstance() {
    if(_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  ~MPSGraphCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  // Returns the cached graph for `key`, creating it via `createCacheBlock` on
  // a miss. Runs under the serial queue; exceptions from the block are
  // rethrown on the caller's thread.
  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {

    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // verify the cached entry doesn't already exist
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
  }

  // Lookup only; returns nullptr on a miss.
  MPSCachedGraph* LookUp(const std::string& key) const {

    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync(serialQueue_, ^() {

      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T *>(LookUp(key));
  }

private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // this is defined in OperationUtils.mm to not include
  // MPSProfiler.h in header OperationUtils.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};
// Common template for creating graph with a specified cache if missing
// Looks up a cached graph of type T under `key`; on a miss, builds a fresh
// MPSGraph, wraps it in a new T, and lets `instantiate` populate the graph's
// tensors before the entry is inserted into the cache.
template<typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize graph
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

// Warns (once) when an int64 tensor hits an op that lacks native int64 support
// on macOS < 13.3 and will be computed in a narrower dtype.
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                           \
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                       \
     TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name,                                                         \
     ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
  }

/**
 * Returns distance from lowest to highest element offset in given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t);

/**
 * Checks whether tensor is mapped to a contiguous area in the storage.
 */
inline bool is_dense_in_storage(const at::Tensor& t) {
  // Dense iff the occupied storage span has exactly numel() elements.
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}
- class MetalShaderLibrary {
- public:
- MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {}
- MetalShaderLibrary(const MetalShaderLibrary&) = delete;
- inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
- return getLibraryPipelineState(getLibrary(), fname);
- }
- id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
- return getLibraryPipelineState(getLibrary(params), fname);
- }
- private:
- id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
- id<MTLLibrary> getLibrary();
- id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
- id<MTLLibrary> compileLibrary(const std::string& src);
- std::string shaderSource;
- unsigned nparams;
- id<MTLLibrary> library = nil;
- std::unordered_map<std::string, id<MTLLibrary>> libMap;
- std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
- };
// Binds a tensor's backing MTLBuffer to argument slot `idx` of the encoder,
// offset to the tensor's own view into the storage.
static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
  [encoder setBuffer:getMTLBufferStorage(t)
              offset:t.storage_offset() * t.element_size()
             atIndex:idx];
}
- static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
- id<MTLComputePipelineState> cplState,
- uint32_t length) {
- const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
- auto size = MTLSizeMake(length, 1, 1);
- auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
- [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
- }
// Builds per-element byte offsets for strided iteration; 64-bit indexing is
// opt-in for tensors too large for 32-bit offsets.
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);

// Helpers assembling the feeds/results dictionaries (graph tensor -> tensor
// data) from 1..4 placeholders.
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
  };
}

// Convenience overload writing into a single result placeholder.
inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}
// Complex dtypes require macOS 14+ MPS support.
inline bool supportsComplex() {
  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}
- // MPS yet to support double types, but starting from MacOS 14, supports bfloat16
- inline bool supportedFloatingType(ScalarType dtype) {
- return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
- }
- inline bool supportedFloatingType(const Tensor& t) {
- return supportedFloatingType(t.scalar_type());
- }
- inline bool supportedFloatingOrComplexType(ScalarType dtype) {
- if (dtype == kComplexFloat || dtype == kComplexHalf) {
- return supportsComplex();
- }
- return supportedFloatingType(dtype);
- }
- inline bool supportedFloatingOrComplexType(const Tensor& t) {
- return supportedFloatingOrComplexType(t.scalar_type());
- }
- inline bool needsGather(const Tensor& t) {
- return !t.is_contiguous() || t.storage_offset();
- }
- } // namespace at::native::mps
|