Skip to content

Commit cb66e6c

Browse files
tqchenyongwww
authored andcommitted
[WEB] Reduce memleak in web runtime (apache#14086)
This PR robustifies the web runtime to reduce memory leak and enhances the runtime with object support. Specifically we introduce scoping and auto-release mechanism when we exit the scope. The improvements are helpful to deal with memory leak in wasm and webgpu settings
1 parent b956917 commit cb66e6c

File tree

13 files changed

+670
-161
lines changed

13 files changed

+670
-161
lines changed

web/.eslintignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
dist
2+
debug

web/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ The following is an example to reproduce this.
8181
- Start the WebSocket RPC
8282
- Browswer version: open https://localhost:8888, click connect to proxy
8383
- NodeJS version: `npm run rpc`
84-
- run `python tests/node/websock_rpc_test.py` to run the rpc client.
84+
- run `python tests/python/websock_rpc_test.py` to run the rpc test.
8585

8686

8787
## WebGPU Experiments

web/apps/node/example.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
3131
// the async version of the API.
3232
tvmjs.instantiate(wasmSource, new EmccWASI())
3333
.then((tvm) => {
34+
tvm.beginScope();
3435
const log_info = tvm.getGlobalFunc("testing.log_info_str");
3536
log_info("hello world");
3637
// List all the global functions from the runtime.
3738
console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames());
39+
tvm.endScope();
3840
});

web/emcc/wasm_runtime.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
#include <tvm/runtime/logging.h>
3333

3434
#include "src/runtime/c_runtime_api.cc"
35+
#include "src/runtime/container.cc"
3536
#include "src/runtime/contrib/sort/sort.cc"
3637
#include "src/runtime/cpu_device_api.cc"
3738
#include "src/runtime/file_utils.cc"
38-
#include "src/runtime/graph_executor/graph_executor.cc"
3939
#include "src/runtime/library_module.cc"
4040
#include "src/runtime/logging.cc"
4141
#include "src/runtime/module.cc"

web/src/ctypes.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ export type FTVMModGetFunction = (
4646
* TVMModuleHandle dep);
4747
*/
4848
export type FTVMModImport = (mod: Pointer, dep: Pointer) => number;
49+
4950
/**
5051
* int TVMModFree(TVMModuleHandle mod);
5152
*/
@@ -161,6 +162,27 @@ export type FTVMBackendPackedCFunc = (
161162
argValues: Pointer, argCodes: Pointer, nargs: number,
162163
outValue: Pointer, outCode: Pointer) => number;
163164

165+
166+
/**
167+
* int TVMObjectFree(TVMObjectHandle obj);
168+
*/
169+
export type FTVMObjectFree = (obj: Pointer) => number;
170+
171+
/**
172+
* int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
173+
*/
174+
export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number;
175+
176+
/**
177+
* int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);
178+
*/
179+
export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number;
180+
181+
/**
182+
* int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
183+
*/
184+
export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number;
185+
164186
// -- TVM Wasm Auxiliary C API --
165187

166188
/** void* TVMWasmAllocSpace(int size); */

web/src/index.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
export {
2121
Scalar, DLDevice, DLDataType,
22-
PackedFunc, Module, NDArray, Instance,
23-
instantiate
22+
PackedFunc, Module, NDArray,
23+
TVMArray,
24+
Instance, instantiate
2425
} from "./runtime";
2526
export { Disposable, LibraryProvider } from "./types";
2627
export { RPCServer } from "./rpc_server";

web/src/rpc_server.ts

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
2222
import { detectGPUDevice } from "./webgpu";
2323
import * as compact from "./compact";
2424
import * as runtime from "./runtime";
25+
import { timeStamp } from "console";
26+
import { Disposable } from "./types";
2527

2628
enum RPCServerState {
2729
InitHeader,
@@ -83,6 +85,7 @@ export class RPCServer {
8385
private pendingSend: Promise<void> = Promise.resolve();
8486
private name: string;
8587
private inst?: runtime.Instance = undefined;
88+
private globalObjects: Array<Disposable> = [];
8689
private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void;
8790
private currPacketHeader?: Uint8Array;
8891
private currPacketLength = 0;
@@ -121,6 +124,9 @@ export class RPCServer {
121124
// eslint-disable-next-line @typescript-eslint/no-unused-vars
122125
private onClose(_event: CloseEvent): void {
123126
if (this.inst !== undefined) {
127+
this.globalObjects.forEach(obj => {
128+
obj.dispose();
129+
});
124130
this.inst.dispose();
125131
}
126132
if (this.state == RPCServerState.ReceivePacketHeader) {
@@ -263,6 +269,9 @@ export class RPCServer {
263269
}
264270

265271
this.inst = inst;
272+
// begin scope to allow handling of objects
273+
// the object should stay alive during all sessions.
274+
this.inst.beginScope();
266275
const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");
267276

268277
const messageHandler = fcreate(
@@ -301,8 +310,10 @@ export class RPCServer {
301310
this.name,
302311
this.key
303312
);
304-
305-
fcreate.dispose();
313+
// message handler should persist across RPC runs
314+
this.globalObjects.push(
315+
this.inst.detachFromCurrentScope(messageHandler)
316+
);
306317
const writeFlag = this.inst.scalar(3, "int32");
307318

308319
this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => {
@@ -320,7 +331,6 @@ export class RPCServer {
320331
// register the callback to redirect the session to local.
321332
const flocal = this.inst.getGlobalFunc("wasm.LocalSession");
322333
const localSession = flocal();
323-
flocal.dispose();
324334
assert(localSession instanceof runtime.Module);
325335

326336
// eslint-disable-next-line @typescript-eslint/no-unused-vars
@@ -333,13 +343,14 @@ export class RPCServer {
333343
);
334344
messageHandler(header, writeFlag);
335345
messageHandler(body, writeFlag);
336-
localSession.dispose();
337346

338347
this.log("Finish initializing the Wasm Server..");
339348
this.requestBytes(SizeOf.I64);
340349
this.state = RPCServerState.ReceivePacketHeader;
341350
// call process events in case there are bufferred data.
342351
this.processEvents();
352+
// recycle all values.
353+
this.inst.endScope();
343354
};
344355

345356
this.state = RPCServerState.WaitForCallback;

0 commit comments

Comments
 (0)