diff --git a/cpp/wllama-context.h b/cpp/wllama-context.h index 984fd815..e5902312 100644 --- a/cpp/wllama-context.h +++ b/cpp/wllama-context.h @@ -180,13 +180,18 @@ struct wllama_context } } - std::pair get_next_result() + std::pair get_next_result() { server_task_result_ptr result = rd->next(should_stop); if (result) - return {result->to_json().dump(), result->is_error()}; + { + const bool is_error = result->is_error(); + return {std::move(result), is_error}; + } else - return {"", false}; + { + return {nullptr, false}; + } } kv_dump dump_metadata() @@ -492,11 +497,31 @@ struct wllama_context glue_msg_get_result_res res; bool has_more = run_loop(); - auto [data_json, is_error] = get_next_result(); + auto [result, is_error] = get_next_result(); + + json data_json; + if (result) + { + auto *res = dynamic_cast(result.get()); + if (res) + { + // special handling for embeddings OAI-compat + json body = {{"model", meta->model_name}}; + json responses = json::array(); + responses.push_back(result->to_json()); + // TODO: support base64 output + data_json = format_embeddings_response_oaicompat(body, meta->model_name, responses, false); + } + else + { + // otherwise, it should be a completion result, nothing special to do + data_json = result->to_json(); + } + } res.success.value = true; res.has_more.value = has_more; - res.data_json.value = data_json; + res.data_json.value = result ? data_json.dump() : ""; res.is_error.value = is_error; return res; } diff --git a/examples/basic/index.html b/examples/basic/index.html index a1f89108..9e59c0aa 100644 --- a/examples/basic/index.html +++ b/examples/basic/index.html @@ -133,10 +133,10 @@

Embeddings

setEmbdDisable(true); let embdA = await wllama1.createEmbedding({ input: elemInputA.value }); console.log({embdA}); - embdA = embdA.embedding; // OAI-compat response + embdA = embdA.data[0].embedding; // OAI-compat response let embdB = await wllama1.createEmbedding({ input: elemInputB.value }); console.log({embdB}); - embdB = embdB.embedding; // OAI-compat response + embdB = embdB.data[0].embedding; // OAI-compat response // since embeddings are normalized, we don't need to calculate norm const dotProd = embdA.reduce((acc, _, i) => acc + embdA[i]*embdB[i], 0); elemOutputEmbd.textContent = dotProd; diff --git a/examples/embeddings/index.html b/examples/embeddings/index.html index f91c3099..9ceb5c72 100644 --- a/examples/embeddings/index.html +++ b/examples/embeddings/index.html @@ -59,7 +59,7 @@ print(`Calculating embedding for sentence #${i}: "${sentence}"`); timeStart(); const result = await wllama.createEmbedding({ input: sentence }); - const vector = result.embedding; + const vector = result.data[0].embedding; print(`OK, take ${timeEnd()} ms`); embeddings.push(vector); } diff --git a/llama.cpp b/llama.cpp index 928b486b..2dfeca31 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 928b486b0c8ef4a126086e078126cdb42e977fc7 +Subproject commit 2dfeca31cc34ffa03c510e75438455d9c6773ffb diff --git a/package.json b/package.json index 97e49c46..d5f79b36 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@wllama/wllama", - "version": "3.1.0", + "version": "3.1.1", "description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference", "main": "index.js", "type": "module", diff --git a/src/wasm-from-cdn.ts b/src/wasm-from-cdn.ts index 579f5023..c7e1fd9d 100644 --- a/src/wasm-from-cdn.ts +++ b/src/wasm-from-cdn.ts @@ -2,7 +2,7 @@ // Do not edit this file directly const WasmFromCDN = { - default: 'https://cdn.jsdelivr.net/npm/@wllama/wllama@3.1.0/src/wasm/wllama.wasm', + default: 'https://cdn.jsdelivr.net/npm/@wllama/wllama@3.1.1/src/wasm/wllama.wasm', }; export default WasmFromCDN; \ No newline at end of file diff --git a/src/wasm/wllama.wasm b/src/wasm/wllama.wasm index cf56d4b3..ee2bdd08 100755 Binary files a/src/wasm/wllama.wasm and b/src/wasm/wllama.wasm differ diff --git a/src/wllama.test.ts b/src/wllama.test.ts index ce7612c0..0533b848 100644 --- a/src/wllama.test.ts +++ b/src/wllama.test.ts @@ -172,7 +172,7 @@ test.sequential('generates embeddings', async () => { const res = await wllama.createEmbedding({ input: text }); expect(res).toBeDefined(); - const embedding = (res as any).embedding as number[]; + const embedding = res.data[0].embedding as number[]; expect(Array.isArray(embedding)).toBe(true); expect(embedding.length).toBeGreaterThan(0); for (const e of embedding) { @@ -181,7 +181,7 @@ test.sequential('generates embeddings', async () => { // slightly different text should have high cosine similarity const res2 = await wllama.createEmbedding({ input: text + ' ' }); - const embedding2 = (res2 as any).embedding as number[]; + const embedding2 = res2.data[0].embedding as number[]; const dot = embedding.reduce((acc, v, i) => acc + v * embedding2[i], 0); const norm1 = Math.sqrt(embedding.reduce((acc, v) => acc + v * v, 0)); const norm2 = Math.sqrt(embedding2.reduce((acc, v) => acc + v * v, 0)); diff --git a/src/wllama.ts b/src/wllama.ts index 57cf5e28..195ffad6 100644 --- a/src/wllama.ts +++ b/src/wllama.ts @@ -15,6 +15,7 @@ import CacheManager, { type DownloadOptions } from './cache-manager'; import { ModelManager, Model, type ModelSource } from './model-manager'; import type { GlueMsgCompletionRes, + GlueMsgEmbeddingRes, GlueMsgGetResultRes, GlueMsgLoadRes, } from './glue/messages'; @@ -555,7 +556,7 @@ export class Wllama { ): Promise { this.checkModelLoaded(); - const result = await this.proxy.wllamaAction( + const result = await this.proxy.wllamaAction( 'embedding', { _name: 'embd_req', diff --git a/src/workers-code/generated.ts b/src/workers-code/generated.ts index 68e9f9a5..434b0799 100644 --- a/src/workers-code/generated.ts +++ b/src/workers-code/generated.ts @@ -1,7 +1,7 @@ // This file is auto-generated // To re-generate it, run: npm run build:worker -export const LIBLLAMA_VERSION = 'b9108-928b486'; +export const LIBLLAMA_VERSION = 'b9138-2dfeca3'; export const LLAMA_CPP_WORKER_CODE = "// Start the main llama.cpp\nlet wllamaMalloc;\nlet wllamaStart;\nlet wllamaAction;\nlet wllamaExit;\nlet wllamaDebug;\n\nlet Module = null;\n\n//////////////////////////////////////////////////////////////\n// UTILS\n//////////////////////////////////////////////////////////////\n\n// send message back to main thread\nconst msg = (data, transfer) => postMessage(data, transfer);\n\n// Convert CPP log into JS log\nconst cppLogToJSLog = (line) => {\n const matched = line.match(/@@(DEBUG|INFO|WARN|ERROR)@@(.*)/);\n return !!matched\n ? {\n level: (matched[1] === 'INFO' ? 'debug' : matched[1]).toLowerCase(),\n text: matched[2],\n }\n : { level: 'log', text: line };\n};\n\nconst getHeapU8 = () => {\n const buffer = Module.wasmMemory.buffer;\n return new Uint8Array(buffer);\n};\n\n// Get module config that forwards stdout/err to main thread\nconst getWModuleConfig = (_argMainScriptBlob) => {\n var pathConfig = RUN_OPTIONS.pathConfig;\n var pthreadPoolSize = RUN_OPTIONS.nbThread;\n var argMainScriptBlob = _argMainScriptBlob;\n\n if (!pathConfig['wllama.wasm']) {\n throw new Error('\"wllama.wasm\" is missing in pathConfig');\n }\n return {\n noInitialRun: true,\n print: function (text) {\n if (arguments.length > 1)\n text = Array.prototype.slice.call(arguments).join(' ');\n msg({ verb: 'console.log', args: [text] });\n },\n printErr: function (text) {\n if (arguments.length > 1)\n text = Array.prototype.slice.call(arguments).join(' ');\n const logLine = cppLogToJSLog(text);\n msg({ verb: 'console.' + logLine.level, args: [logLine.text] });\n },\n locateFile: function (filename, basePath) {\n const p = pathConfig[filename];\n const truncate = (str) =>\n str.length > 128 ? `${str.substr(0, 128)}...` : str;\n if (filename.match(/wllama\\.worker\\.js/)) {\n msg({\n verb: 'console.error',\n args: [\n '\"wllama.worker.js\" is removed from v2.2.1. Hint: make sure to clear browser\\'s cache.',\n ],\n });\n } else {\n msg({\n verb: 'console.debug',\n args: [`Loading \"${filename}\" from \"${truncate(p)}\"`],\n });\n return p;\n }\n },\n mainScriptUrlOrBlob: argMainScriptBlob,\n pthreadPoolSize,\n wasmMemory: pthreadPoolSize > 1 ? getWasmMemory() : null,\n onAbort: function (text) {\n msg({ verb: 'signal.abort', args: [text] });\n },\n };\n};\n\n// Get the memory to be used by wasm. (Only used in multi-thread mode)\n// Because we have a weird OOM issue on iOS, we need to try some values\n// See: https://github.com/emscripten-core/emscripten/issues/19144\n// https://github.com/godotengine/godot/issues/70621\nconst getWasmMemory = () => {\n let minBytes = 128 * 1024 * 1024;\n let maxBytes = 4096 * 1024 * 1024;\n let stepBytes = 128 * 1024 * 1024;\n while (maxBytes > minBytes) {\n try {\n const wasmMemory = new WebAssembly.Memory({\n initial: BigInt(minBytes / 65536),\n maximum: BigInt(maxBytes / 65536),\n shared: true,\n address: 'i64',\n });\n return wasmMemory;\n } catch (e) {\n maxBytes -= stepBytes;\n continue; // retry\n }\n }\n throw new Error('Cannot allocate WebAssembly.Memory');\n};\n\n//////////////////////////////////////////////////////////////\n// HEAPFS PATCH\n//////////////////////////////////////////////////////////////\n\n/**\n * By default, emscripten uses memfs. The way it works is by\n * allocating new Uint8Array in javascript heap. This is not good\n * because it requires files to be copied to wasm heap each time\n * a file is read.\n *\n * HeapFS is an alternative, which resolves this problem by\n * allocating space for file directly inside wasm heap. This\n * allows us to mmap without doing any copy.\n *\n * For llama.cpp, this is great because we use MAP_SHARED\n *\n * Ref: https://github.com/ngxson/wllama/pull/39\n * Ref: https://github.com/emscripten-core/emscripten/blob/main/src/library_memfs.js\n *\n * Note 29/05/2024 @ngxson\n * Due to ftell() being limited to MAX_LONG, we cannot load files bigger than 2^31 bytes (or 2GB)\n * Ref: https://github.com/emscripten-core/emscripten/blob/main/system/lib/libc/musl/src/stdio/ftell.c\n */\n\nconst fsNameToFile = {}; // map Name => File\nconst fsIdToFile = {}; // map ID => File\nlet currFileId = 0;\n\n// Patch and redirect memfs calls to wllama\nconst patchHeapFS = () => {\n const m = Module;\n // save functions\n m.MEMFS.stream_ops._read = m.MEMFS.stream_ops.read;\n m.MEMFS.stream_ops._write = m.MEMFS.stream_ops.write;\n m.MEMFS.stream_ops._llseek = m.MEMFS.stream_ops.llseek;\n m.MEMFS.stream_ops._allocate = m.MEMFS.stream_ops.allocate;\n m.MEMFS.stream_ops._mmap = m.MEMFS.stream_ops.mmap;\n m.MEMFS.stream_ops._msync = m.MEMFS.stream_ops.msync;\n\n const patchStream = (stream) => {\n const name = stream.node.name;\n if (fsNameToFile[name]) {\n const f = fsNameToFile[name];\n stream.node.contents = getHeapU8().subarray(f.ptr, f.ptr + f.size);\n stream.node.usedBytes = f.size;\n }\n };\n\n // replace \"read\" functions\n m.MEMFS.stream_ops.read = function (\n stream,\n buffer,\n offset,\n length,\n position\n ) {\n patchStream(stream);\n return m.MEMFS.stream_ops._read(stream, buffer, offset, length, position);\n };\n m.MEMFS.ops_table.file.stream.read = m.MEMFS.stream_ops.read;\n\n // replace \"llseek\" functions\n m.MEMFS.stream_ops.llseek = function (stream, offset, whence) {\n patchStream(stream);\n return m.MEMFS.stream_ops._llseek(stream, offset, whence);\n };\n m.MEMFS.ops_table.file.stream.llseek = m.MEMFS.stream_ops.llseek;\n\n // replace \"mmap\" functions\n m.MEMFS.stream_ops.mmap = function (stream, length, position, prot, flags) {\n patchStream(stream);\n const name = stream.node.name;\n if (fsNameToFile[name]) {\n const f = fsNameToFile[name];\n return {\n ptr: f.ptr + position,\n allocated: false,\n };\n } else {\n return m.MEMFS.stream_ops._mmap(stream, length, position, prot, flags);\n }\n };\n m.MEMFS.ops_table.file.stream.mmap = m.MEMFS.stream_ops.mmap;\n\n // mount FS\n m.FS.mkdir('/models');\n m.FS.mount(m.MEMFS, { root: '.' }, '/models');\n};\n\n// Allocate a new file in wllama heapfs, returns file ID\nconst heapfsAlloc = (name, size) => {\n if (size < 1) {\n throw new Error('File size must be bigger than 0');\n }\n const m = Module;\n const ptr = m.mmapAlloc(size);\n const file = {\n ptr: ptr,\n size: size,\n id: currFileId++,\n };\n fsIdToFile[file.id] = file;\n fsNameToFile[name] = file;\n return file.id;\n};\n\n// Add new file to wllama heapfs, return number of written bytes\nconst heapfsWrite = (id, buffer, offset) => {\n const m = Module;\n if (fsIdToFile[id]) {\n const { ptr, size } = fsIdToFile[id];\n const afterWriteByte = offset + buffer.byteLength;\n if (afterWriteByte > size) {\n throw new Error(\n `File ID ${id} write out of bound, afterWriteByte = ${afterWriteByte} while size = ${size}`\n );\n }\n getHeapU8().set(buffer, ptr + offset);\n return buffer.byteLength;\n } else {\n throw new Error(`File ID ${id} not found in heapfs`);\n }\n};\n\n//////////////////////////////////////////////////////////////\n// MAIN CODE\n//////////////////////////////////////////////////////////////\n\nconst callWrapper = (name, ret, args, isAsync) => {\n const fn = Module.cwrap(\n name,\n ret,\n args,\n isAsync ? { async: true } : undefined\n );\n return async (action, req) => {\n // console.log(`Calling ${name} with action:`, action, 'and req:', req);\n let result;\n try {\n if (args.length === 2) {\n result = isAsync ? await fn(action, req) : fn(action, req);\n } else {\n result = fn();\n }\n } catch (ex) {\n console.error(ex);\n throw ex;\n }\n return result;\n };\n};\n\nonmessage = async (e) => {\n if (!e.data) return;\n const { verb, args, callbackId } = e.data;\n\n if (!callbackId) {\n msg({ verb: 'console.error', args: ['callbackId is required', e.data] });\n return;\n }\n\n if (verb === 'module.init') {\n const argMainScriptBlob = args[0];\n try {\n Module = getWModuleConfig(argMainScriptBlob);\n Module.preRun = () => {\n // ENV can be set here (for future use)\n };\n Module.onRuntimeInitialized = () => {\n // async call once module is ready\n // init FS\n patchHeapFS();\n // init cwrap\n const pointer = 'bigint';\n // TODO: note sure why emscripten cannot bind if there is only 1 argument\n wllamaMalloc = callWrapper('wllama_malloc', pointer, [\n 'number',\n pointer,\n ]);\n wllamaStart = callWrapper('wllama_start', 'string', [], true);\n wllamaAction = callWrapper(\n 'wllama_action',\n pointer,\n ['string', pointer],\n true\n );\n wllamaExit = callWrapper('wllama_exit', 'string', []);\n wllamaDebug = callWrapper('wllama_debug', 'string', []);\n msg({ callbackId, result: null });\n };\n wModuleInit();\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'fs.alloc') {\n const argFilename = args[0];\n const argSize = args[1];\n try {\n // create blank file\n const emptyBuffer = new ArrayBuffer(0);\n Module['FS_createDataFile'](\n '/models',\n argFilename,\n emptyBuffer,\n true,\n true,\n true\n );\n // alloc data on heap\n const fileId = heapfsAlloc(argFilename, argSize);\n msg({ callbackId, result: { fileId } });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'fs.write') {\n const argFileId = args[0];\n const argBuffer = args[1];\n const argOffset = args[2];\n try {\n const writtenBytes = heapfsWrite(argFileId, argBuffer, argOffset);\n msg({ callbackId, result: { writtenBytes } });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.start') {\n try {\n const result = await wllamaStart();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.action') {\n const argAction = args[0];\n const argEncodedMsg = args[1];\n try {\n const inputPtr = await wllamaMalloc(BigInt(argEncodedMsg.byteLength), 0);\n // copy data to wasm heap\n const inputBuffer = new Uint8Array(\n getHeapU8().buffer,\n Number(inputPtr),\n argEncodedMsg.byteLength\n );\n inputBuffer.set(argEncodedMsg, 0);\n const outputPtr = await wllamaAction(argAction, inputPtr);\n // length of output buffer is written at the first 4 bytes of input buffer\n const outputLen = new Uint32Array(\n getHeapU8().buffer,\n Number(inputPtr),\n 1\n )[0];\n // copy the output buffer to JS heap\n const outputBuffer = new Uint8Array(outputLen);\n const outputSrcView = new Uint8Array(\n getHeapU8().buffer,\n Number(outputPtr),\n outputLen\n );\n outputBuffer.set(outputSrcView, 0); // copy it\n msg({ callbackId, result: outputBuffer }, [outputBuffer.buffer]);\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.exit') {\n try {\n const result = await wllamaExit();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.debug') {\n try {\n const result = await wllamaDebug();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n};\n";