Skip to content

Commit 835c439

Browse files
committed
[RUNTIME][METAL] Provide richer runtime when error happens
This PR enhances metal runtime to include more error messages when error happens.
1 parent 6a877df commit 835c439

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

src/runtime/metal/metal_common.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <memory>
3939
#include <mutex>
4040
#include <string>
41+
#include <utility>
4142
#include <vector>
4243

4344
#include "../workspace_pool.h"
@@ -106,25 +107,35 @@ class AutoReleasePoolWrapper {
106107
*/
107108
class Stream {
108109
public:
109-
explicit Stream(id<MTLDevice> device) : error_happened_(false) {
110-
queue_ = [device newCommandQueue];
111-
}
110+
explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
112111
~Stream() { [queue_ release]; }
113-
id<MTLCommandBuffer> GetCommandBuffer() {
112+
id<MTLCommandBuffer> GetCommandBuffer(bool attach_error_callback = true) {
114113
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
115114
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
116-
if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
115+
if (buffer.status == MTLCommandBufferStatusError) {
116+
ICHECK(buffer.error != nil);
117+
this->SetError(buffer.error.localizedDescription.UTF8String);
118+
}
117119
}];
118120
return cb;
119121
}
120-
bool HasErrorHappened() { return error_happened_; }
122+
123+
void SetError(std::string error_description) {
124+
error_happened_ = true;
125+
error_description_ = std::move(error_description);
126+
}
127+
128+
bool HasErrorHappened() const { return error_happened_; }
129+
130+
const std::string& ErrorDescription() const { return error_description_; }
121131

122132
private:
123-
void SetErrorStatus() { error_happened_ = true; }
124133
// Queue
125134
id<MTLCommandQueue> queue_;
126135
// Check if error happened in one previous run
127-
bool error_happened_;
136+
bool error_happened_{false};
137+
// error description
138+
std::string error_description_;
128139
};
129140

130141
/*!

src/runtime/metal/metal_device_api.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ int GetWarpSize(id<MTLDevice> dev) {
222222
if (dev_from.device_type == kDLCPU) dev = dev_to;
223223
Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
224224
if (s->HasErrorHappened()) {
225-
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
225+
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
226226
}
227227
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
228228
int from_dev_type = static_cast<int>(dev_from.device_type);
@@ -301,7 +301,7 @@ int GetWarpSize(id<MTLDevice> dev) {
301301
[cb commit];
302302
[cb waitUntilCompleted];
303303
if (s->HasErrorHappened()) {
304-
LOG(FATAL) << "Error! Some problems on GPU happaned!";
304+
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
305305
}
306306
};
307307
}

src/runtime/metal/metal_module.mm

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,19 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
194194
// obtain the stream
195195
auto stream =
196196
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id);
197+
198+
// skip launching so the error can be printed during sync
197199
if (stream->HasErrorHappened()) return;
200+
198201
if (scache_[device_id] == nil) {
199202
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
200203
}
201204
ThreadWorkLoad wl = launch_param_config_.Extract(args);
202205
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
203206
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
204207
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
205-
id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
208+
// attach error message directly in this functio
209+
id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/* attach_error_callback= */ false);
206210
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
207211
[encoder setComputePipelineState:scache_[device_id]];
208212
for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -219,6 +223,16 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
219223
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
220224
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
221225
[encoder endEncoding];
226+
// attach error message with function name
227+
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
228+
if (buffer.status == MTLCommandBufferStatusError) {
229+
ICHECK(buffer.error != nil);
230+
std::ostringstream os;
231+
os << "GPUError happens after running " << func_name_ << ": "
232+
<< buffer.error.localizedDescription.UTF8String;
233+
stream->SetError(os.str());
234+
}
235+
}];
222236
[cb commit];
223237
};
224238
}

0 commit comments

Comments
 (0)