4242
4343void MetalWorkspace::GetAttr (Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
4444 AUTORELEASEPOOL {
45- this ->Init ();
4645 size_t index = static_cast <size_t >(dev.device_id );
4746 if (kind == kExist ) {
4847 *rv = int (index < devices.size ());
@@ -142,29 +141,18 @@ int GetWarpSize(id<MTLDevice> dev) {
142141 }
143142}
144143
145- void MetalWorkspace::ReinitializeStreams () {
146- std::vector<Stream*>& threadStreams = MetalThreadEntry::ThreadLocal ()->stream ;
147- ICHECK_EQ (default_streams_.size (), threadStreams.size ());
144+ void MetalWorkspace::ReinitializeDefaultStreams () {
148145 for (size_t i = 0 ; i < default_streams_.size (); ++i) {
149- if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])
150- delete threadStreams[i];
151146 delete default_streams_[i];
152147 }
153148 default_streams_.resize (devices.size ());
154- threadStreams.resize (devices.size ());
155149 for (size_t i = 0 ; i < devices.size (); ++i) {
156150 Stream* stream = new Stream (devices[i]);
157151 default_streams_[i] = stream;
158- threadStreams[i] = stream;
159152 }
160153}
161154
162- void MetalWorkspace::Init () {
163- if (initialized_) return ;
164- std::lock_guard<std::mutex> lock (this ->mutex );
165- if (initialized_) return ;
166- initialized_ = true ;
167- if (devices.size () != 0 ) return ;
155+ MetalWorkspace::MetalWorkspace () {
168156#if TARGET_OS_IPHONE
169157 // on iPhone
170158 id <MTLDevice > d = MTLCreateSystemDefaultDevice ();
@@ -178,7 +166,7 @@ int GetWarpSize(id<MTLDevice> dev) {
178166 warp_size.push_back (GetWarpSize (d));
179167 }
180168#endif
181- ReinitializeStreams ();
169+ this -> ReinitializeDefaultStreams ();
182170}
183171
184172void MetalWorkspace::SetDevice (Device dev) {
@@ -189,7 +177,6 @@ int GetWarpSize(id<MTLDevice> dev) {
189177 DLDataType type_hint) {
190178 id <MTLBuffer > buf;
191179 AUTORELEASEPOOL {
192- this ->Init ();
193180 id <MTLDevice > dev = GetDevice (device);
194181 // GPU memory only
195182 MTLResourceOptions storage_mode = MTLResourceStorageModePrivate ;
@@ -220,20 +207,20 @@ int GetWarpSize(id<MTLDevice> dev) {
220207 };
221208}
222209
223- Stream* CastStreamOrGetCurrent (TVMStreamHandle stream, int device_id) {
210+ Stream* MetalWorkspace::CastStreamOrGetDefault (TVMStreamHandle stream, int device_id) {
224211 if (stream != nullptr ) return static_cast <Stream*>(stream);
225- ICHECK (MetalThreadEntry::ThreadLocal ()->stream [device_id] != nullptr );
226- return MetalThreadEntry::ThreadLocal ()->stream [device_id];
212+ ICHECK_LT (static_cast <size_t >(device_id), default_streams_.size ());
213+ ICHECK (default_streams_[device_id] != nullptr );
214+ return default_streams_[device_id];
227215}
228216
229217void MetalWorkspace::CopyDataFromTo (const void * from, size_t from_offset, void * to,
230218 size_t to_offset, size_t size, Device dev_from, Device dev_to,
231219 DLDataType type_hint, TVMStreamHandle stream) {
232220 AUTORELEASEPOOL {
233- this ->Init ();
234221 Device dev = dev_from;
235222 if (dev_from.device_type == kDLCPU ) dev = dev_to;
236- Stream* s = CastStreamOrGetCurrent (stream, dev.device_id );
223+ Stream* s = this -> CastStreamOrGetDefault (stream, dev.device_id );
237224 if (s->HasErrorHappened ()) {
238225 LOG (FATAL) << " Error! Some problems on GPU happaned! Cannot copy data to current stream" ;
239226 }
@@ -303,15 +290,12 @@ int GetWarpSize(id<MTLDevice> dev) {
303290void MetalWorkspace::FreeStream (Device dev, TVMStreamHandle stream) {
304291 ICHECK (stream != nullptr );
305292 ICHECK_LT (dev.device_id , devices.size ()) << " Invalid device id " << dev.device_id ;
306- Stream* s = static_cast <Stream*>(stream);
307- if (MetalThreadEntry::ThreadLocal ()->stream [dev.device_id] == s)
308- MetalThreadEntry::ThreadLocal ()->stream [dev.device_id] = nullptr ;
309- delete s;
293+ delete static_cast <Stream*>(stream);
310294}
311295
312296void MetalWorkspace::StreamSync (Device dev, TVMStreamHandle stream) {
313297 AUTORELEASEPOOL {
314- Stream* s = CastStreamOrGetCurrent (stream, dev.device_id );
298+ Stream* s = CastStreamOrGetDefault (stream, dev.device_id );
315299 // commit an empty command buffer and wait until it completes.
316300 id <MTLCommandBuffer > cb = s->GetCommandBuffer ();
317301 [cb commit ];
@@ -325,7 +309,7 @@ int GetWarpSize(id<MTLDevice> dev) {
325309void MetalWorkspace::SetStream (Device dev, TVMStreamHandle stream) {
326310 ICHECK_LT (dev.device_id , devices.size ()) << " Invalid device id " << dev.device_id ;
327311 ICHECK (stream != nullptr );
328- MetalThreadEntry::ThreadLocal ()->stream [dev.device_id] = static_cast <Stream*>( stream) ;
312+ MetalThreadEntry::ThreadLocal ()->stream [dev.device_id] = stream;
329313}
330314
331315void * MetalWorkspace::AllocWorkspace (Device dev, size_t size, DLDataType type_hint) {
@@ -374,7 +358,7 @@ int GetWarpSize(id<MTLDevice> dev) {
374358});
375359
376360TVM_REGISTER_GLOBAL (" metal.ResetGlobalState" ).set_body_typed([]() {
377- MetalWorkspace::Global ()->ReinitializeStreams ();
361+ MetalWorkspace::Global ()->ReinitializeDefaultStreams ();
378362});
379363
380364} // namespace metal
0 commit comments