Skip to content

Commit 49fc613

Browse files
authored
[Unity][WEBGPU] Enable wasm exception propagation (#16330)
This PR enables wasm exception propagation among c++ runtime generated wasm and javascript. Right now the error.message is passed back this would allow us to do some handling in webgpu related exceptions raised through FFI boundaries. Note that this would require the latest emscripten and on the nodejs, --experimental-wasm-eh support.
1 parent d509661 commit 49fc613

File tree

7 files changed

+60
-7
lines changed

7 files changed

+60
-7
lines changed

python/tvm/contrib/emcc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"):
4242
cmd += ["-O3"]
4343
cmd += ["-std=c++17"]
4444
cmd += ["--no-entry"]
45+
cmd += ["-fwasm-exception"]
4546
cmd += ["-s", "WASM_BIGINT=1"]
4647
cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"]
4748
cmd += ["-s", "STANDALONE_WASM=1"]

web/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt
2727

2828
EMCC = emcc
2929

30-
EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes
30+
EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes -fwasm-exceptions
3131

3232
EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\
3333
-s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js

web/emcc/tvmjs_support.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,15 @@ class AsyncLocalSession : public LocalSession {
148148
int code = args[0];
149149
TVMRetValue rv;
150150
rv = args[1];
151-
this->EncodeReturn(std::move(rv),
152-
[&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); });
151+
if (code == static_cast<int>(RPCCode::kReturn)) {
152+
this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
153+
callback(RPCCode::kReturn, encoded_args);
154+
});
155+
} else {
156+
// for exception, we can pass through as since this is just normal encoding.
157+
ICHECK_EQ(code, static_cast<int>(RPCCode::kException));
158+
callback(RPCCode::kException, args);
159+
}
153160
});
154161

155162
TVMRetValue temp;

web/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"build": "rollup -c",
1414
"lint": "eslint -c .eslintrc.json .",
1515
"typedoc": "typedoc src/index.ts --plugin typedoc-plugin-missing-exports",
16-
"test": "jest",
16+
"test": "node --experimental-wasm-eh node_modules/.bin/jest",
1717
"bundle": "npm run build && cp lib/index.js dist/index.js && cp lib/index.js dist/tvmjs.bundle.js",
1818
"example": "npm run bundle && node apps/node/example.js",
1919
"example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js",

web/src/ctypes.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ export type PtrOffset = number;
3333
*/
3434
export type FTVMGetLastError = () => Pointer;
3535

36+
/**
37+
* void TVMAPISetLastError(const char* msg);
38+
*/
39+
export type FTVMAPISetLastError = (msg: Pointer) => void;
40+
3641
/**
3742
* int TVMModGetFunction(TVMModuleHandle mod,
3843
* const char* func_name,

web/src/runtime.ts

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class FFILibrary implements Disposable {
7878
if (code != 0) {
7979
const msgPtr = (this.exports
8080
.TVMGetLastError as ctypes.FTVMGetLastError)();
81+
console.log("Here");
8182
throw new Error("TVMError: " + this.memory.loadCString(msgPtr));
8283
}
8384
}
@@ -1902,10 +1903,15 @@ export class Instance implements Disposable {
19021903
// need to keep it alive until callback is fulfilled.
19031904
const callback = this.detachFromCurrentScope(args[args.length - 1] as PackedFunc);
19041905
const promise: Promise<any> = func(...fargs);
1905-
promise.then((rv: any) => {
1906+
const onFulfilled = (rv: any) => {
19061907
callback(this.scalar(AsyncCallbackCode.kReturn, "int32"), rv);
19071908
callback.dispose();
1908-
});
1909+
};
1910+
const onRejected = (reason: any) => {
1911+
callback(this.scalar(AsyncCallbackCode.kException, "int32"), reason.toString());
1912+
callback.dispose();
1913+
};
1914+
promise.then(onFulfilled, onRejected);
19091915
};
19101916
this.registerFunc("__async." + name, asyncVariant, override);
19111917
}
@@ -2216,7 +2222,26 @@ export class Instance implements Disposable {
22162222
jsArgs.push(this.retValueToJS(valuePtr, tcode, true));
22172223
}
22182224

2219-
const rv = func(...jsArgs);
2225+
let rv: any;
2226+
try {
2227+
rv = func(...jsArgs);
2228+
} catch (error) {
2229+
// error handling
2230+
// store error via SetLastError
2231+
this.ctx.endScope();
2232+
const errMsg = "JSCallbackError: " + error.message;
2233+
const stack = lib.getOrAllocCallStack();
2234+
const errMsgOffset = stack.allocRawBytes(errMsg.length + 1);
2235+
stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg));
2236+
stack.commitToWasmMemory();
2237+
(this.lib.exports.TVMAPISetLastError as ctypes.FTVMAPISetLastError)(
2238+
stack.ptrFromOffset(errMsgOffset)
2239+
);
2240+
this.lib.recycleCallStack(stack);
2241+
return -1;
2242+
}
2243+
2244+
// normal return path
22202245
// recycle all js object value in function unless we want to retain them.
22212246
this.ctx.endScope();
22222247

web/tests/node/test_packed_func.js

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ test("RegisterGlobal", () => {
126126
tvm.endScope();
127127
});
128128

129+
test("ExceptionPassing", () => {
130+
tvm.beginScope();
131+
tvm.registerFunc("throw_error", function (msg) {
132+
throw Error(msg);
133+
});
134+
let f = tvm.getGlobalFunc("throw_error");
135+
try {
136+
f("error-xyz");
137+
throw Error("error not caught");
138+
} catch (error) {
139+
assert(error.message.indexOf("error-xyz") != -1);
140+
}
141+
tvm.endScope();
142+
});
143+
129144
test("NDArrayCbArg", () => {
130145
tvm.beginScope();
131146
let use_count = tvm.getGlobalFunc("testing.object_use_count");

0 commit comments

Comments
 (0)