Skip to content

Commit f0c28a0

Browse files
MasterJH5574tqchen
andauthored
[RUNTIME][METAL] Fix multithreading access of metal runtime (#16605)
This PR fixes a bug where metal runtime cannot be accessed from multiple threads. This is because the ThreadLocal entry initialization happens during global workspace initialization, meaning other threads that tries to use metal runtime later cannot have the thread local entry correctly initialized. This PR fixes the problem by always use nullptr fallback and lookup at the global workspace for default stream. Co-authored-by: tqchen <[email protected]>
1 parent 94e83f2 commit f0c28a0

File tree

3 files changed

+32
-42
lines changed

3 files changed

+32
-42
lines changed

src/runtime/metal/metal_common.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ class MetalWorkspace final : public DeviceAPI {
136136
std::vector<id<MTLDevice>> devices;
137137
// Warp size constant
138138
std::vector<int> warp_size;
139-
// Whether it is initialized.
140-
bool initialized_{false};
141-
// the mutex for initialization
142-
std::mutex mutex;
139+
MetalWorkspace();
143140
// Destructor
144141
~MetalWorkspace();
145142
// Get device for given device
@@ -149,9 +146,6 @@ class MetalWorkspace final : public DeviceAPI {
149146
<< "Invalid Metal device_id=" << dev.device_id;
150147
return devices[dev.device_id];
151148
}
152-
// Initialize workspace
153-
// Return false if already initialized, otherwise return true.
154-
void Init();
155149
// override device API
156150
void SetDevice(Device dev) final;
157151
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
@@ -163,7 +157,16 @@ class MetalWorkspace final : public DeviceAPI {
163157
void SetStream(Device dev, TVMStreamHandle stream) final;
164158
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
165159
void FreeWorkspace(Device dev, void* data) final;
166-
void ReinitializeStreams();
160+
void ReinitializeDefaultStreams();
161+
162+
/**
163+
* Cast stream to the right metal stream data structure
164+
* if stream is nullptr , return the default stream of device_id
165+
* \param stream the input stream handle
166+
* \param device_id The device id of interest
167+
* \returns The stream used in this function.
168+
*/
169+
Stream* CastStreamOrGetDefault(TVMStreamHandle stream, int device_id);
167170

168171
// get the global workspace
169172
static MetalWorkspace* Global();
@@ -184,7 +187,7 @@ class MetalThreadEntry {
184187
/*! \brief The current device */
185188
Device device;
186189
/*! \brief The current stream */
187-
std::vector<Stream*> stream;
190+
std::vector<TVMStreamHandle> stream;
188191
/*! \brief The shared buffer used for copy. */
189192
std::vector<id<MTLBuffer>> temp_buffer_;
190193
/*! \brief workspace pool */
@@ -193,6 +196,10 @@ class MetalThreadEntry {
193196
MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
194197
device.device_id = 0;
195198
device.device_type = static_cast<DLDeviceType>(kDLMetal);
199+
MetalWorkspace* global_ws = MetalWorkspace::Global();
200+
// by default, set the stream to nullptr, which indicate
201+
// that we are using default stream
202+
this->stream.resize(global_ws->devices.size(), nullptr);
196203
}
197204
~MetalThreadEntry();
198205
// Get temp buffer with at least size under dev.

src/runtime/metal/metal_device_api.mm

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
4444
AUTORELEASEPOOL {
45-
this->Init();
4645
size_t index = static_cast<size_t>(dev.device_id);
4746
if (kind == kExist) {
4847
*rv = int(index < devices.size());
@@ -142,29 +141,18 @@ int GetWarpSize(id<MTLDevice> dev) {
142141
}
143142
}
144143

145-
void MetalWorkspace::ReinitializeStreams() {
146-
std::vector<Stream*>& threadStreams = MetalThreadEntry::ThreadLocal()->stream;
147-
ICHECK_EQ(default_streams_.size(), threadStreams.size());
144+
void MetalWorkspace::ReinitializeDefaultStreams() {
148145
for (size_t i = 0; i < default_streams_.size(); ++i) {
149-
if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])
150-
delete threadStreams[i];
151146
delete default_streams_[i];
152147
}
153148
default_streams_.resize(devices.size());
154-
threadStreams.resize(devices.size());
155149
for (size_t i = 0; i < devices.size(); ++i) {
156150
Stream* stream = new Stream(devices[i]);
157151
default_streams_[i] = stream;
158-
threadStreams[i] = stream;
159152
}
160153
}
161154

162-
void MetalWorkspace::Init() {
163-
if (initialized_) return;
164-
std::lock_guard<std::mutex> lock(this->mutex);
165-
if (initialized_) return;
166-
initialized_ = true;
167-
if (devices.size() != 0) return;
155+
MetalWorkspace::MetalWorkspace() {
168156
#if TARGET_OS_IPHONE
169157
// on iPhone
170158
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
@@ -178,7 +166,7 @@ int GetWarpSize(id<MTLDevice> dev) {
178166
warp_size.push_back(GetWarpSize(d));
179167
}
180168
#endif
181-
ReinitializeStreams();
169+
this->ReinitializeDefaultStreams();
182170
}
183171

184172
void MetalWorkspace::SetDevice(Device dev) {
@@ -189,7 +177,6 @@ int GetWarpSize(id<MTLDevice> dev) {
189177
DLDataType type_hint) {
190178
id<MTLBuffer> buf;
191179
AUTORELEASEPOOL {
192-
this->Init();
193180
id<MTLDevice> dev = GetDevice(device);
194181
// GPU memory only
195182
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
@@ -220,20 +207,20 @@ int GetWarpSize(id<MTLDevice> dev) {
220207
};
221208
}
222209

223-
Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {
210+
Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int device_id) {
224211
if (stream != nullptr) return static_cast<Stream*>(stream);
225-
ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);
226-
return MetalThreadEntry::ThreadLocal()->stream[device_id];
212+
ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
213+
ICHECK(default_streams_[device_id] != nullptr);
214+
return default_streams_[device_id];
227215
}
228216

229217
void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
230218
size_t to_offset, size_t size, Device dev_from, Device dev_to,
231219
DLDataType type_hint, TVMStreamHandle stream) {
232220
AUTORELEASEPOOL {
233-
this->Init();
234221
Device dev = dev_from;
235222
if (dev_from.device_type == kDLCPU) dev = dev_to;
236-
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
223+
Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
237224
if (s->HasErrorHappened()) {
238225
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
239226
}
@@ -303,15 +290,12 @@ int GetWarpSize(id<MTLDevice> dev) {
303290
void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
304291
ICHECK(stream != nullptr);
305292
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
306-
Stream* s = static_cast<Stream*>(stream);
307-
if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s)
308-
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr;
309-
delete s;
293+
delete static_cast<Stream*>(stream);
310294
}
311295

312296
void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
313297
AUTORELEASEPOOL {
314-
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
298+
Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
315299
// commit an empty command buffer and wait until it completes.
316300
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
317301
[cb commit];
@@ -325,7 +309,7 @@ int GetWarpSize(id<MTLDevice> dev) {
325309
void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
326310
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
327311
ICHECK(stream != nullptr);
328-
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
312+
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
329313
}
330314

331315
void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
@@ -374,7 +358,7 @@ int GetWarpSize(id<MTLDevice> dev) {
374358
});
375359

376360
TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() {
377-
MetalWorkspace::Global()->ReinitializeStreams();
361+
MetalWorkspace::Global()->ReinitializeDefaultStreams();
378362
});
379363

380364
} // namespace metal

src/runtime/metal/metal_module.mm

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
191191
AUTORELEASEPOOL {
192192
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
193193
int device_id = t->device.device_id;
194-
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
194+
// obtain the stream
195+
auto stream =
196+
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id);
195197
if (stream->HasErrorHappened()) return;
196198
if (scache_[device_id] == nil) {
197199
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
@@ -265,10 +267,7 @@ Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
265267
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
266268
std::string source) {
267269
ObjectPtr<Object> n;
268-
AUTORELEASEPOOL {
269-
metal::MetalWorkspace::Global()->Init();
270-
n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
271-
};
270+
AUTORELEASEPOOL { n = make_object<MetalModuleNode>(smap, fmap, fmt, source); };
272271
return Module(n);
273272
}
274273

0 commit comments

Comments
 (0)