@@ -42,7 +42,9 @@ import { Conversation, compareConversationObject, getConversation } from "./conv
4242 * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
4343 * `engineConfig.appConfig`.
4444 * @param engineConfig Optionally configures the engine, see `webllm.EngineConfig`.
45- * @returns An intialized `WebLLM.Engine` with `modelId` loaded.
45+ * @returns An initialized `WebLLM.Engine` with `modelId` loaded.
46+ * @throws Throws an error when the device is lost (mostly due to OOM); users should call
47+ * `CreateEngine()` again, potentially with a smaller model or a smaller context window size.
4648 */
4749export async function CreateEngine (
4850 modelId : string ,
@@ -70,7 +72,7 @@ export class Engine implements EngineInterface {
7072 private pipeline ?: LLMChatPipeline ;
7173 private initProgressCallback ?: InitProgressCallback ;
7274 private interruptSignal = false ;
73- private deviceLostIsError = false ; // whether device.lost is due to actual error or model reload
75+ private deviceLostIsError = true ; // whether device.lost is due to actual error or model reload
7476 private config ?: ChatConfig ;
7577
7678 constructor ( ) {
@@ -89,8 +91,16 @@ export class Engine implements EngineInterface {
8991 this . logitProcessorRegistry = logitProcessorRegistry ;
9092 }
9193
94+ /**
95+ * Reload model `modelId`.
96+ * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
97+ * `engineConfig.appConfig`.
98+ * @param chatOpts To optionally override the `mlc-chat-config.json` of `modelId`.
99+ * @param appConfig Configure the app with the list of models and whether to use IndexedDB cache.
100+ * @throws Throws an error when the device is lost (mostly due to OOM); users should call
101+ * `reload()` again, potentially with a smaller model or a smaller context window size.
102+ */
92103 async reload ( modelId : string , chatOpts ?: ChatOptions , appConfig ?: AppConfig ) : Promise < void > {
93- this . deviceLostIsError = false ; // so that unload() does not trigger device.lost warning
94104 this . unload ( ) ;
95105
96106 this . logitProcessor = this . logitProcessorRegistry ?. get ( modelId ) ;
@@ -195,15 +205,21 @@ export class Engine implements EngineInterface {
195205 }
196206 }
197207
198- tvm . initWebGPU ( gpuDetectOutput . device ) ;
208+ // Most device losses happen in `reload()` since we allocate memory ahead of time. So we can
209+ // use this flag at the end of `reload()` to make the error handling synchronous.
210+ // This `.then()` exists throughout the lifetime of the device. Though we have not
211+ // experienced a device error outside of `reload()`, it is still possible that this `.then()` is
212+ // triggered outside of `reload()`. TODO: does this cause unexpected behavior?
213+ let deviceLostInReload = false ;
199214 gpuDetectOutput . device . lost . then ( ( info : any ) => {
200- // `fetchNDArrayCache` may exceed available memory; use `lost.then` to prevent crashing
201215 if ( this . deviceLostIsError ) {
202216 console . error ( "Device was lost, please try to initialize again. " , info ) ;
203217 this . unload ( ) ;
218+ deviceLostInReload = true ;
204219 }
205220 } ) ;
206- this . deviceLostIsError = true ;
221+ tvm . initWebGPU ( gpuDetectOutput . device ) ;
222+
207223 const tokenizer = await this . asyncLoadTokenizer ( modelUrl , this . config , appConfig ) ;
208224 const cacheType = appConfig . useIndexedDBCache ? "indexeddb" : "cache" ;
209225 await tvm . fetchNDArrayCache ( modelUrl , tvm . webgpu ( ) , "webllm/model" , cacheType ) ;
@@ -220,6 +236,13 @@ export class Engine implements EngineInterface {
220236 } )
221237 }
222238 this . currentModelId = modelId ;
239+
240+ if ( deviceLostInReload ) {
241+ throw Error (
242+ "WebGPU device lost during `reload()`.\n This is probably due to OOM, try reload with a " +
243+ "model that has less parameters or a smaller context length."
244+ ) ;
245+ }
223246 }
224247
225248 async generate (
@@ -479,9 +502,11 @@ export class Engine implements EngineInterface {
479502 }
480503
481504 async unload ( ) {
505+ this . deviceLostIsError = false ; // so that unload() does not trigger device.lost error
482506 this . pipeline ?. dispose ( ) ;
483507 this . pipeline = undefined ;
484508 this . currentModelId = undefined ;
509+ this . deviceLostIsError = true ;
485510 }
486511
487512 async getMaxStorageBufferBindingSize ( ) : Promise < number > {
0 commit comments