
Commit b762bf4

[Device] Catch WebGPU OOM error (#402)
Prior to this PR, when users call `CreateEngine()` or `reload()` with a model that is too large for the device, the device would likely keep generating, ignoring the OOM issue and compromising correctness. See #356 and #209. This PR catches such errors with `device.lost.then()`, relying on tvmjs to call `device.destroy()` upon detecting an error in `createBuffer()` via apache/tvm#17005. We have only observed `createBuffer()` errors and hence only handle that kind of error for now. Besides, since most OOM errors occur in `reload()`, we make the error handling synchronous despite using `.then()`: we throw the error at the end of `reload()` if one occurred.
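With this change, a too-large model surfaces as a rejected promise from `CreateEngine()` / `reload()` rather than a silently broken engine. A minimal caller-side sketch of the intended recovery path (re-calling with a smaller model); the package import path and the model IDs are illustrative assumptions, not taken from this commit:

import * as webllm from "@mlc-ai/web-llm";

// Hypothetical model IDs: try a larger model first, fall back to a smaller one.
const preferredModel = "Llama-2-13b-chat-hf-q4f16_1";
const fallbackModel = "Llama-2-7b-chat-hf-q4f16_1";

async function createEngineWithFallback(): Promise<webllm.EngineInterface> {
  try {
    // After this PR, CreateEngine() rejects if the WebGPU device is lost
    // during reload(), which is most likely an OOM.
    return await webllm.CreateEngine(preferredModel);
  } catch (err) {
    console.warn("Device lost (likely OOM); retrying with a smaller model:", err);
    return await webllm.CreateEngine(fallbackModel);
  }
}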
1 parent 3481fed · commit b762bf4


src/engine.ts

Lines changed: 31 additions & 6 deletions
@@ -42,7 +42,9 @@ import { Conversation, compareConversationObject, getConversation } from "./conv
  * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
  * `engineConfig.appConfig`.
  * @param engineConfig Optionally configures the engine, see `webllm.EngineConfig`.
- * @returns An intialized `WebLLM.Engine` with `modelId` loaded.
+ * @returns An initialized `WebLLM.Engine` with `modelId` loaded.
+ * @throws Throws error when device lost (mostly due to OOM); users should re-call `CreateEngine()`,
+ * potentially with a smaller model or smaller context window size.
  */
 export async function CreateEngine(
   modelId: string,
@@ -70,7 +72,7 @@ export class Engine implements EngineInterface {
   private pipeline?: LLMChatPipeline;
   private initProgressCallback?: InitProgressCallback;
   private interruptSignal = false;
-  private deviceLostIsError = false; // whether device.lost is due to actual error or model reload
+  private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
   private config?: ChatConfig;
 
   constructor() {
@@ -89,8 +91,16 @@ export class Engine implements EngineInterface {
     this.logitProcessorRegistry = logitProcessorRegistry;
   }
 
+  /**
+   * Reload model `modelId`.
+   * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
+   * `engineConfig.appConfig`.
+   * @param chatOpts To optionally override the `mlc-chat-config.json` of `modelId`.
+   * @param appConfig Configure the app with the list of models and whether to use IndexedDB cache.
+   * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(),
+   * potentially with a smaller model or smaller context window size.
+   */
   async reload(modelId: string, chatOpts?: ChatOptions, appConfig?: AppConfig): Promise<void> {
-    this.deviceLostIsError = false; // so that unload() does not trigger device.lost warning
     this.unload();
 
     this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
@@ -195,15 +205,21 @@
       }
     }
 
-    tvm.initWebGPU(gpuDetectOutput.device);
+    // Most device lost happens in `reload()` since we allocate memory ahead of time. So we can
+    // use this flag at the end of `reload()` to make the error handling synchronous.
+    // This `.then()` exists throughout the lifetime of the device. Though we have not
+    // experienced device error outside of `reload()`, it is still possible this `.then()` is
+    // triggered outside of `reload()`. TODO: does this cause unexpected behavior?
+    let deviceLostInReload = false;
     gpuDetectOutput.device.lost.then((info: any) => {
-      // `fetchNDArrayCache` may exceed available memory; use `lost.then` to prevent crashing
       if (this.deviceLostIsError) {
         console.error("Device was lost, please try to initialize again. ", info);
         this.unload();
+        deviceLostInReload = true;
       }
     });
-    this.deviceLostIsError = true;
+    tvm.initWebGPU(gpuDetectOutput.device);
+
     const tokenizer = await this.asyncLoadTokenizer(modelUrl, this.config, appConfig);
     const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache";
     await tvm.fetchNDArrayCache(modelUrl, tvm.webgpu(), "webllm/model", cacheType);
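A note on the hunk above: the `device.lost.then()` callback can run while `reload()` is suspended at one of its later `await`s, so checking a local flag after those awaits turns the asynchronous device loss into a synchronous throw for the caller. A stripped-down sketch of that pattern, with a plain promise standing in for `device.lost` (all names here are illustrative, not part of web-llm):

// Illustrative only: `lost` stands in for GPUDevice.lost.
async function loadModel(lost: Promise<string>): Promise<void> {
  let lostDuringLoad = false;
  lost.then((reason) => {
    // May fire while one of the awaits below is pending.
    lostDuringLoad = true;
    console.error("Device was lost:", reason);
  });

  // Stand-in for the awaited loading steps (tokenizer, weights, ...).
  await new Promise((resolve) => setTimeout(resolve, 0));

  // Any loss that happened mid-load has already set the flag, so the
  // failure surfaces as an ordinary rejection of this async function.
  if (lostDuringLoad) {
    throw Error("Device lost during load; retry with a smaller model.");
  }
}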
@@ -220,6 +236,13 @@
       })
     }
     this.currentModelId = modelId;
+
+    if (deviceLostInReload) {
+      throw Error(
+        "WebGPU device lost during `reload()`.\n This is probably due to OOM, try reload with a " +
+          "model that has less parameters or a smaller context length."
+      );
+    }
   }
 
   async generate(

@@ -479,9 +502,11 @@
   }
 
   async unload() {
+    this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
     this.pipeline?.dispose();
     this.pipeline = undefined;
     this.currentModelId = undefined;
+    this.deviceLostIsError = true;
   }
 
   async getMaxStorageBufferBindingSize(): Promise<number> {
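Since `reload()` now rethrows the device-lost error at its end, an existing engine can also recover in place, per the new `@throws` docs above. A sketch of that usage (model IDs hypothetical; assumes the same `webllm` import as the earlier example):

async function reloadWithFallback(
  engine: webllm.EngineInterface,
  modelId: string,
  fallbackModelId: string,
): Promise<void> {
  try {
    await engine.reload(modelId);
  } catch (err) {
    // A plain try/catch around the await is enough: the error is thrown
    // synchronously at the end of reload() rather than from a detached .then().
    console.warn("reload() failed (likely OOM); trying a smaller model:", err);
    await engine.reload(fallbackModelId);
  }
}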