Skip to content

Commit 4d1915e

Browse files
authored
[Fix] Avoid unnecessary engine reload by correctly comparing ChatOption and AppConfig objects (#399)
## Overview Currently, even the client is initializing the worker engine with the exact same configurations, the worker will not correctly recognize this but instead it will unnecessarily re-initialize itself. The root cause is due to the use of `===` to compare object equity which is actually comparing object reference equity instead of value equity. This PR fixed it by create utility functions for deep comparing the **VALUE** of these config objects. The code is tedious and thus I generated using AI models. ## Test Tested on https://chat.neet.coffee with the following code added to `web_service_worker.ts`. ```typescript console.log("modelId same? " + this.modelId === params.modelId); console.log("chatOpts same? " + areChatOptionsEqual(this.chatOpts, params.chatOpts)); console.log("appConfig same? " + areAppConfigsEqual(this.appConfig, params.appConfig)); ``` Before the fix: ``` modelId same? true chatOpts same? false appConfig same? false ``` After: ``` modelId same? true chatOpts same? true appConfig same? true Already loaded the model. Skip loading ```
1 parent 3379591 commit 4d1915e

File tree

3 files changed

+133
-6
lines changed

3 files changed

+133
-6
lines changed

src/service_worker.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
WebWorkerEngine,
99
PostMessageHandler,
1010
} from "./web_worker";
11+
import { areAppConfigsEqual, areChatOptionsEqual } from "./utils";
1112

1213
/**
1314
* A post message handler that sends messages to a chrome.runtime.Port.
@@ -84,8 +85,8 @@ export class ServiceWorkerEngineHandler extends EngineWorkerHandler {
8485
// If the modelId, chatOpts, and appConfig are the same, immediately return
8586
if (
8687
this.modelId === params.modelId &&
87-
this.chatOpts === params.chatOpts &&
88-
this.appConfig === params.appConfig
88+
areChatOptionsEqual(this.chatOpts, params.chatOpts) &&
89+
areAppConfigsEqual(this.appConfig, params.appConfig)
8990
) {
9091
console.log("Already loaded the model. Skip loading");
9192
const gpuDetectOutput = await tvmjs.detectGPUDevice();

src/utils.ts

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import { AppConfig, ChatOptions, ModelRecord } from "./config";
2+
3+
// Helper function to compare two arrays
4+
function areArraysEqual(arr1?: Array<any>, arr2?: Array<any>): boolean {
5+
if (!arr1 && !arr2) return true;
6+
if (!arr1 || !arr2) return false;
7+
if (arr1.length !== arr2.length) return false;
8+
for (let i = 0; i < arr1.length; i++) {
9+
if (arr1[i] !== arr2[i]) return false;
10+
}
11+
return true;
12+
}
13+
14+
// Helper function to compare two objects deeply
15+
function areObjectsEqual(obj1: any, obj2: any): boolean {
16+
if (obj1 === obj2) return true;
17+
if (typeof obj1 !== typeof obj2) return false;
18+
if (typeof obj1 !== "object" || obj1 === null || obj2 === null) return false;
19+
20+
const keys1 = Object.keys(obj1);
21+
const keys2 = Object.keys(obj2);
22+
if (keys1.length !== keys2.length) return false;
23+
24+
for (const key of keys1) {
25+
if (!keys2.includes(key) || !areObjectsEqual(obj1[key], obj2[key]))
26+
return false;
27+
}
28+
return true;
29+
}
30+
31+
// Function to compare two ModelRecord instances
32+
export function areModelRecordsEqual(
33+
record1: ModelRecord,
34+
record2: ModelRecord
35+
): boolean {
36+
// Compare primitive fields
37+
if (
38+
record1.model_url !== record2.model_url ||
39+
record1.model_id !== record2.model_id ||
40+
record1.model_lib_url !== record2.model_lib_url ||
41+
record1.vram_required_MB !== record2.vram_required_MB ||
42+
record1.low_resource_required !== record2.low_resource_required ||
43+
record1.buffer_size_required_bytes !== record2.buffer_size_required_bytes
44+
) {
45+
return false;
46+
}
47+
48+
// Compare required_features arrays
49+
if (
50+
(record1.required_features && !record2.required_features) ||
51+
(!record1.required_features && record2.required_features)
52+
) {
53+
return false;
54+
}
55+
56+
if (record1.required_features && record2.required_features) {
57+
if (record1.required_features.length !== record2.required_features.length) {
58+
return false;
59+
}
60+
61+
for (let i = 0; i < record1.required_features.length; i++) {
62+
if (record1.required_features[i] !== record2.required_features[i]) {
63+
return false;
64+
}
65+
}
66+
}
67+
68+
return true;
69+
}
70+
71+
export function areAppConfigsEqual(
72+
config1?: AppConfig,
73+
config2?: AppConfig
74+
): boolean {
75+
if (config1 === undefined || config2 === undefined) {
76+
return config1 === config2;
77+
}
78+
79+
// Check if both configurations have the same IndexedDB cache usage
80+
if (config1.useIndexedDBCache !== config2.useIndexedDBCache) {
81+
return false;
82+
}
83+
84+
// Check if both configurations have the same number of model records
85+
if (config1.model_list.length !== config2.model_list.length) {
86+
return false;
87+
}
88+
89+
// Compare each ModelRecord in the model_list
90+
for (let i = 0; i < config1.model_list.length; i++) {
91+
if (!areModelRecordsEqual(config1.model_list[i], config2.model_list[i])) {
92+
return false;
93+
}
94+
}
95+
96+
// If all checks passed, the configurations are equal
97+
return true;
98+
}
99+
100+
export function areChatOptionsEqual(
101+
options1?: ChatOptions,
102+
options2?: ChatOptions
103+
): boolean {
104+
if (options1 === undefined || options2 === undefined) {
105+
return options1 === options2;
106+
}
107+
// Compare each property of ChatOptions (which are Partial<ChatConfig>)
108+
if (!areArraysEqual(options1.tokenizer_files, options2.tokenizer_files))
109+
return false;
110+
if (!areObjectsEqual(options1.conv_config, options2.conv_config))
111+
return false;
112+
if (options1.conv_template !== options2.conv_template) return false;
113+
if (options1.mean_gen_len !== options2.mean_gen_len) return false;
114+
if (options1.max_gen_len !== options2.max_gen_len) return false;
115+
if (options1.shift_fill_factor !== options2.shift_fill_factor) return false;
116+
if (options1.repetition_penalty !== options2.repetition_penalty) return false;
117+
if (options1.frequency_penalty !== options2.frequency_penalty) return false;
118+
if (options1.presence_penalty !== options2.presence_penalty) return false;
119+
if (options1.top_p !== options2.top_p) return false;
120+
if (options1.temperature !== options2.temperature) return false;
121+
if (options1.bos_token_id !== options2.bos_token_id) return false;
122+
123+
// If all checks passed, the options are equal
124+
return true;
125+
}

src/web_service_worker.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
PostMessageHandler,
99
ChatWorker,
1010
} from "./web_worker";
11+
import { areAppConfigsEqual, areChatOptionsEqual } from "./utils";
1112

1213
const BROADCAST_CHANNEL_SERVICE_WORKER_ID = "@mlc-ai/web-llm-sw";
1314
const BROADCAST_CHANNEL_CLIENT_ID = "@mlc-ai/web-llm-client";
@@ -79,8 +80,8 @@ export class ServiceWorkerEngineHandler extends EngineWorkerHandler {
7980
// If the modelId, chatOpts, and appConfig are the same, immediately return
8081
if (
8182
this.modelId === params.modelId &&
82-
this.chatOpts === params.chatOpts &&
83-
this.appConfig === params.appConfig
83+
areChatOptionsEqual(this.chatOpts, params.chatOpts) &&
84+
areAppConfigsEqual(this.appConfig, params.appConfig)
8485
) {
8586
console.log("Already loaded the model. Skip loading");
8687
const gpuDetectOutput = await tvmjs.detectGPUDevice();
@@ -147,8 +148,8 @@ export async function CreateServiceWorkerEngine(
147148
*/
148149
export class ServiceWorkerEngine extends WebWorkerEngine {
149150
constructor(worker: ChatWorker) {
150-
super(worker)
151-
clientBroadcastChannel.onmessage = this.onmessage.bind(this)
151+
super(worker);
152+
clientBroadcastChannel.onmessage = this.onmessage.bind(this);
152153
}
153154

154155
keepAlive() {

0 commit comments

Comments
 (0)