|
13 | 13 | "\n",
|
14 | 14 | "from panel.custom import JSComponent, ESMEvent\n",
|
15 | 15 | "\n",
|
16 |  | - "pn.extension(template='material')" |
| 16 | + "pn.extension('mathjax', template='material')" |
17 | 17 | ]
|
18 | 18 | },
|
19 | 19 | {
|
|
31 | 31 | "metadata": {},
|
32 | 32 | "outputs": [],
|
33 | 33 | "source": [
|
34 |  | - "\n", |
35 | 34 | "MODELS = {\n",
|
36 |  | - " 'Mistral-7b-Instruct': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',\n", |
37 |  | - " 'SmolLM': 'SmolLM-360M-Instruct-q4f16_1-MLC',\n", |
38 |  | - " 'Gemma-2b': 'gemma-2-2b-it-q4f16_1-MLC',\n", |
39 |  | - " 'Llama-3.1-8b-Instruct': 'Llama-3.1-8B-Instruct-q4f32_1-MLC-1k'\n", |
| 35 | + " 'SmolLM (130MB)': 'SmolLM-135M-Instruct-q4f16_1-MLC',\n", |
| 36 | + " 'TinyLlama-1.1B-Chat (675 MB)': 'TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC-1k',\n", |
| 37 | + " 'Gemma-2b (1895 MB)': 'gemma-2-2b-it-q4f16_1-MLC',\n", |
| 38 | + " 'Mistral-7b-Instruct (4570 MB)': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',\n", |
| 39 | + " 'Llama-3.1-8b-Instruct (4598 MB)': 'Llama-3-8B-Instruct-q4f16_1-MLC-1k',\n", |
40 | 40 | "}\n",
|
41 | 41 | "\n",
|
42 | 42 | "class WebLLM(JSComponent):\n",
|
43 | 43 | "\n",
|
44 | 44 | " loaded = param.Boolean(default=False, doc=\"\"\"\n",
|
45 | 45 | " Whether the model is loaded.\"\"\")\n",
|
46 | 46 | "\n",
|
47 |  | - " model = param.Selector(default='SmolLM-360M-Instruct-q4f16_1-MLC', objects=MODELS)\n", |
| 47 | + " history = param.Integer(default=3)\n", |
48 | 48 | "\n",
|
49 |  | - " temperature = param.Number(default=1, bounds=(0, 2))\n", |
| 49 | + " status = param.Dict(default={'text': '', 'progress': 0})\n", |
50 | 50 | "\n",
|
51 | 51 | " load_model = param.Event()\n",
|
| 52 | + "\n", |
| 53 | + " model = param.Selector(default='SmolLM-135M-Instruct-q4f16_1-MLC', objects=MODELS)\n", |
| 54 | + "\n", |
| 55 | + " running = param.Boolean(default=False, doc=\"\"\"\n", |
| 56 | + " Whether the LLM is currently running.\"\"\")\n", |
52 | 57 | " \n",
|
| 58 | + " temperature = param.Number(default=1, bounds=(0, 2), doc=\"\"\"\n", |
| 59 | + " Temperature of the model completions.\"\"\")\n", |
| 60 | + "\n", |
53 | 61 | " _esm = \"\"\"\n",
|
54 | 62 | " import * as webllm from \"https://esm.run/@mlc-ai/web-llm\";\n",
|
55 | 63 | "\n",
|
56 | 64 | " const engines = new Map()\n",
|
57 | 65 | "\n",
|
58 | 66 | " export async function render({ model }) {\n",
|
59 | 67 | " model.on(\"msg:custom\", async (event) => {\n",
|
60 |  | - " console.log(event)\n", |
61 | 68 | " if (event.type === 'load') {\n",
|
62 | 69 | " if (!engines.has(model.model)) {\n",
|
63 |  | - " engines.set(model.model, await webllm.CreateMLCEngine(model.model))\n", |
| 70 | + " const initProgressCallback = (status) => {\n", |
| 71 | + " model.status = status\n", |
| 72 | + " }\n", |
| 73 | + " const mlc = await webllm.CreateMLCEngine(\n", |
| 74 | + " model.model,\n", |
| 75 | + " {initProgressCallback}\n", |
| 76 | + " )\n", |
| 77 | + " engines.set(model.model, mlc)\n", |
64 | 78 | " }\n",
|
65 | 79 | " model.loaded = true\n",
|
66 | 80 | " } else if (event.type === 'completion') {\n",
|
|
73 | 87 | " temperature: model.temperature ,\n",
|
74 | 88 | " stream: true,\n",
|
75 | 89 | " })\n",
|
| 90 | + " model.running = true\n", |
76 | 91 | " for await (const chunk of chunks) {\n",
|
| 92 | + " if (!model.running) {\n", |
| 93 | + " break\n", |
| 94 | + " }\n", |
77 | 95 | " model.send_msg(chunk.choices[0])\n",
|
78 | 96 | " }\n",
|
79 | 97 | " }\n",
|
|
83 | 101 | "\n",
|
84 | 102 | " def __init__(self, **params):\n",
|
85 | 103 | " super().__init__(**params)\n",
|
| 104 | + " if pn.state.location:\n", |
| 105 | + " pn.state.location.sync(self, {'model': 'model'})\n", |
86 | 106 | " self._buffer = []\n",
|
87 | 107 | "\n",
|
88 | 108 | " @param.depends('load_model', watch=True)\n",
|
|
93 | 113 | " @param.depends('loaded', watch=True)\n",
|
94 | 114 | " def _loaded(self):\n",
|
95 | 115 | " self.loading = False\n",
|
96 |  | - " self.param.load_model.constant = True\n", |
97 | 116 | "\n",
|
98 | 117 | " @param.depends('model', watch=True)\n",
|
99 | 118 | " def _update_load_model(self):\n",
|
100 |  | - " self.param.load_model.constant = False\n", |
| 119 | + " self.loaded = False\n", |
101 | 120 | "\n",
|
102 | 121 | " def _handle_msg(self, msg):\n",
|
103 |  | - " self._buffer.insert(0, msg)\n", |
| 122 | + " if self.running:\n", |
| 123 | + " self._buffer.insert(0, msg)\n", |
104 | 124 | "\n",
|
105 | 125 | " async def create_completion(self, msgs):\n",
|
106 | 126 | " self._send_msg({'type': 'completion', 'messages': msgs})\n",
|
|
119 | 139 | "\n",
|
120 | 140 | " async def callback(self, contents: str, user: str):\n",
|
121 | 141 | " if not self.loaded:\n",
|
122 |  | - " yield f'Model `{self.model}` is loading.' if self.param.load_model.constant else 'Load the model'\n", |
| 142 | + " if self.loading:\n", |
| 143 | + " yield pn.pane.Markdown(\n", |
| 144 | + " f'## `{self.model}`\\n\\n' + self.param.status.rx()['text']\n", |
| 145 | + " )\n", |
| 146 | + " else:\n", |
| 147 | + " yield 'Load the model'\n", |
123 | 148 | " return\n",
|
| 149 | + " self.running = False\n", |
| 150 | + " self._buffer.clear()\n", |
124 | 151 | " message = \"\"\n",
|
125 | 152 | " async for chunk in llm.create_completion([{'role': 'user', 'content': contents}]):\n",
|
126 | 153 | " message += chunk['delta'].get('content', '')\n",
|
127 | 154 | " yield message\n",
|
128 | 155 | "\n",
|
129 | 156 | " def menu(self):\n",
|
| 157 | + " status = self.param.status.rx()\n", |
130 | 158 | " return pn.Column(\n",
|
131 | 159 | " pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),\n",
|
132 | 160 | " pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),\n",
|
133 | 161 | " pn.widgets.Button.from_param(\n",
|
134 | 162 | " self.param.load_model, sizing_mode='stretch_width',\n",
|
135 |  | - " loading=self.param.loading\n", |
136 |  | - " )\n", |
| 163 | + " disabled=self.param.loaded.rx().rx.or_(self.param.loading)\n", |
| 164 | + " ),\n", |
| 165 | + " pn.indicators.Progress(\n", |
| 166 | + " value=(status['progress']*100).rx.pipe(int), visible=self.param.loading,\n", |
| 167 | + " sizing_mode='stretch_width'\n", |
| 168 | + " ),\n", |
| 169 | + " pn.pane.Markdown(status['text'], visible=self.param.loading)\n", |
137 | 170 | " )"
|
138 | 171 | ]
|
139 | 172 | },
|
|
154 | 187 | "source": [
|
155 | 188 | "llm = WebLLM()\n",
|
156 | 189 | "\n",
|
157 |  | - "pn.Column(llm.menu(), llm).servable(area='sidebar')" |
| 190 | + "intro = pn.pane.Alert(\"\"\"\n", |
| 191 | + "`WebLLM` runs large-language models entirely in your browser.\n", |
| 192 | + "When visiting the application the first time the model has\n", |
| 193 | + "to be downloaded and loaded into memory, which may take \n", |
| 194 | + "some time. Models are ordered by size (and capability),\n", |
| 195 | + "e.g. SmolLLM is very quick to download but produces poor\n", |
| 196 | + "quality output while Mistral-7b will take a while to\n", |
| 197 | + "download but produces much higher quality output.\n", |
| 198 | + "\"\"\".replace('\\n', ' '))\n", |
| 199 | + "\n", |
| 200 | + "pn.Column(\n", |
| 201 | + " llm.menu(),\n", |
| 202 | + " intro,\n", |
| 203 | + " llm\n", |
| 204 | + ").servable(area='sidebar')" |
158 | 205 | ]
|
159 | 206 | },
|
160 | 207 | {
|
|
179 | 226 | " respond=False,\n",
|
180 | 227 | ")\n",
|
181 | 228 | "\n",
|
182 |  | - "chat_interface.servable(title='WebLLM')" |
| 229 | + "llm.param.watch(lambda e: chat_interface.send(f'Loaded `{e.obj.model}`, start chatting!', user='System', respond=False), 'loaded')\n", |
| 230 | + "\n", |
| 231 | + "pn.Row(chat_interface).servable(title='WebLLM')" |
183 | 232 | ]
|
184 | 233 | }
|
185 | 234 | ],
|
|