diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index dad156bcdddc..d9154e0f7906 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -136,10 +136,7 @@ class MetalWorkspace final : public DeviceAPI { std::vector> devices; // Warp size constant std::vector warp_size; - // Whether it is initialized. - bool initialized_{false}; - // the mutex for initialization - std::mutex mutex; + MetalWorkspace(); // Destructor ~MetalWorkspace(); // Get device for given device @@ -149,9 +146,6 @@ class MetalWorkspace final : public DeviceAPI { << "Invalid Metal device_id=" << dev.device_id; return devices[dev.device_id]; } - // Initialize workspace - // Return false if already initialized, otherwise return true. - void Init(); // override device API void SetDevice(Device dev) final; void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; @@ -163,7 +157,16 @@ class MetalWorkspace final : public DeviceAPI { void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; - void ReinitializeStreams(); + void ReinitializeDefaultStreams(); + + /** + * Cast stream to the right metal stream data structure. + * If stream is nullptr, return the default stream of device_id. + * \param stream The input stream handle + * \param device_id The device id of interest + * \returns The stream used in this function. + */ + Stream* CastStreamOrGetDefault(TVMStreamHandle stream, int device_id); // get the global workspace static MetalWorkspace* Global(); @@ -184,7 +187,7 @@ class MetalThreadEntry { /*! \brief The current device */ Device device; /*! \brief The current stream */ - std::vector stream; + std::vector stream; /*! \brief The shared buffer used for copy. */ std::vector> temp_buffer_; /*! 
\brief workspace pool */ @@ -193,6 +196,10 @@ class MetalThreadEntry { MetalThreadEntry() : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { device.device_id = 0; device.device_type = static_cast(kDLMetal); + MetalWorkspace* global_ws = MetalWorkspace::Global(); + // By default, set the stream to nullptr, which indicates + // that we are using the default stream. + this->stream.resize(global_ws->devices.size(), nullptr); } ~MetalThreadEntry(); // Get temp buffer with at least size under dev. diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index c4ffc8943c01..e3853ef6d62a 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -42,7 +42,6 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { AUTORELEASEPOOL { - this->Init(); size_t index = static_cast(dev.device_id); if (kind == kExist) { *rv = int(index < devices.size()); @@ -142,29 +141,18 @@ int GetWarpSize(id dev) { } } -void MetalWorkspace::ReinitializeStreams() { - std::vector& threadStreams = MetalThreadEntry::ThreadLocal()->stream; - ICHECK_EQ(default_streams_.size(), threadStreams.size()); +void MetalWorkspace::ReinitializeDefaultStreams() { for (size_t i = 0; i < default_streams_.size(); ++i) { - if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i]) - delete threadStreams[i]; delete default_streams_[i]; } default_streams_.resize(devices.size()); - threadStreams.resize(devices.size()); for (size_t i = 0; i < devices.size(); ++i) { Stream* stream = new Stream(devices[i]); default_streams_[i] = stream; - threadStreams[i] = stream; } } -void MetalWorkspace::Init() { - if (initialized_) return; - std::lock_guard lock(this->mutex); - if (initialized_) return; - initialized_ = true; - if (devices.size() != 0) return; +MetalWorkspace::MetalWorkspace() { #if TARGET_OS_IPHONE // on iPhone id d = MTLCreateSystemDefaultDevice(); @@ -178,7 +166,7 @@ int GetWarpSize(id dev) { 
warp_size.push_back(GetWarpSize(d)); } #endif - ReinitializeStreams(); + this->ReinitializeDefaultStreams(); } void MetalWorkspace::SetDevice(Device dev) { @@ -189,7 +177,6 @@ int GetWarpSize(id dev) { DLDataType type_hint) { id buf; AUTORELEASEPOOL { - this->Init(); id dev = GetDevice(device); // GPU memory only MTLResourceOptions storage_mode = MTLResourceStorageModePrivate; @@ -220,20 +207,20 @@ int GetWarpSize(id dev) { }; } -Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) { +Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int device_id) { if (stream != nullptr) return static_cast(stream); - ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr); - return MetalThreadEntry::ThreadLocal()->stream[device_id]; + ICHECK_LT(static_cast(device_id), default_streams_.size()); + ICHECK(default_streams_[device_id] != nullptr); + return default_streams_[device_id]; } void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { AUTORELEASEPOOL { - this->Init(); Device dev = dev_from; if (dev_from.device_type == kDLCPU) dev = dev_to; - Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); + Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id); if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! 
Cannot copy data to current stream"; } @@ -303,15 +290,12 @@ int GetWarpSize(id dev) { void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { ICHECK(stream != nullptr); ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; - Stream* s = static_cast(stream); - if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s) - MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr; - delete s; + delete static_cast(stream); } void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { - Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); + Stream* s = CastStreamOrGetDefault(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); [cb commit]; @@ -325,7 +309,7 @@ int GetWarpSize(id dev) { void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; ICHECK(stream != nullptr); - MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream; } void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { @@ -374,7 +358,7 @@ int GetWarpSize(id dev) { }); TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { - MetalWorkspace::Global()->ReinitializeStreams(); + MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); } // namespace metal diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 98e32cdf9caa..01d107942664 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -191,7 +191,9 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons AUTORELEASEPOOL { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; - auto stream = static_cast(t->stream[device_id]); + // 
obtain the stream + auto stream = + metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id); if (stream->HasErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); @@ -265,10 +267,7 @@ Module MetalModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string fmt, std::string source) { ObjectPtr n; - AUTORELEASEPOOL { - metal::MetalWorkspace::Global()->Init(); - n = make_object(smap, fmap, fmt, source); - }; + AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; return Module(n); }