Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions cpp/wllama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,18 @@ struct wllama_context
}
}

std::pair<std::string, bool> get_next_result()
std::pair<server_task_result_ptr, bool> get_next_result()
{
server_task_result_ptr result = rd->next(should_stop);
if (result)
return {result->to_json().dump(), result->is_error()};
{
const bool is_error = result->is_error();
return {std::move(result), is_error};
}
else
return {"", false};
{
return {nullptr, false};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}

kv_dump dump_metadata()
Expand Down Expand Up @@ -492,11 +497,31 @@ struct wllama_context
glue_msg_get_result_res res;

bool has_more = run_loop();
auto [data_json, is_error] = get_next_result();
auto [result, is_error] = get_next_result();

json data_json;
if (result)
{
auto *res = dynamic_cast<server_task_result_embd *>(result.get());
if (res)
{
// special handling for embeddings OAI-compat
json body = {{"model", meta->model_name}};
json responses = json::array();
responses.push_back(result->to_json());
// TODO: support base64 output
data_json = format_embeddings_response_oaicompat(body, meta->model_name, responses, false);
}
else
{
// otherwise, it should be a completion result, nothing special to do
data_json = result->to_json();
}
}

res.success.value = true;
res.has_more.value = has_more;
res.data_json.value = data_json;
res.data_json.value = result ? data_json.dump() : "";
res.is_error.value = is_error;
return res;
}
Expand Down
4 changes: 2 additions & 2 deletions examples/basic/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ <h2>Embeddings</h2>
setEmbdDisable(true);
let embdA = await wllama1.createEmbedding({ input: elemInputA.value });
console.log({embdA});
embdA = embdA.embedding; // OAI-compat response
embdA = embdA.data[0].embedding; // OAI-compat response
let embdB = await wllama1.createEmbedding({ input: elemInputB.value });
console.log({embdB});
embdB = embdB.embedding; // OAI-compat response
embdB = embdB.data[0].embedding; // OAI-compat response
// since embeddings are normalized, we don't need to calculate norm
const dotProd = embdA.reduce((acc, _, i) => acc + embdA[i]*embdB[i], 0);
elemOutputEmbd.textContent = dotProd;
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
print(`Calculating embedding for sentence #${i}: "${sentence}"`);
timeStart();
const result = await wllama.createEmbedding({ input: sentence });
const vector = result.embedding;
const vector = result.data[0].embedding;
print(`OK, take ${timeEnd()} ms`);
embeddings.push(vector);
}
Expand Down
2 changes: 1 addition & 1 deletion llama.cpp
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@wllama/wllama",
"version": "3.1.0",
"version": "3.1.1",
"description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference",
"main": "index.js",
"type": "module",
Expand Down
2 changes: 1 addition & 1 deletion src/wasm-from-cdn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Do not edit this file directly

const WasmFromCDN = {
default: 'https://cdn.jsdelivr.net/npm/@wllama/wllama@3.1.0/src/wasm/wllama.wasm',
default: 'https://cdn.jsdelivr.net/npm/@wllama/wllama@3.1.1/src/wasm/wllama.wasm',
};

export default WasmFromCDN;
Binary file modified src/wasm/wllama.wasm
Binary file not shown.
4 changes: 2 additions & 2 deletions src/wllama.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ test.sequential('generates embeddings', async () => {
const res = await wllama.createEmbedding({ input: text });

expect(res).toBeDefined();
const embedding = (res as any).embedding as number[];
const embedding = res.data[0].embedding as number[];
expect(Array.isArray(embedding)).toBe(true);
expect(embedding.length).toBeGreaterThan(0);
for (const e of embedding) {
Expand All @@ -181,7 +181,7 @@ test.sequential('generates embeddings', async () => {

// slightly different text should have high cosine similarity
const res2 = await wllama.createEmbedding({ input: text + ' ' });
const embedding2 = (res2 as any).embedding as number[];
const embedding2 = res2.data[0].embedding as number[];
const dot = embedding.reduce((acc, v, i) => acc + v * embedding2[i], 0);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
const norm1 = Math.sqrt(embedding.reduce((acc, v) => acc + v * v, 0));
const norm2 = Math.sqrt(embedding2.reduce((acc, v) => acc + v * v, 0));
Expand Down
3 changes: 2 additions & 1 deletion src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import CacheManager, { type DownloadOptions } from './cache-manager';
import { ModelManager, Model, type ModelSource } from './model-manager';
import type {
GlueMsgCompletionRes,
GlueMsgEmbeddingRes,
GlueMsgGetResultRes,
GlueMsgLoadRes,
} from './glue/messages';
Expand Down Expand Up @@ -555,7 +556,7 @@ export class Wllama {
): Promise<CreateEmbeddingResponse> {
this.checkModelLoaded();

const result = await this.proxy.wllamaAction<GlueMsgCompletionRes>(
const result = await this.proxy.wllamaAction<GlueMsgEmbeddingRes>(
'embedding',
{
_name: 'embd_req',
Expand Down
2 changes: 1 addition & 1 deletion src/workers-code/generated.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// This file is auto-generated
// To re-generate it, run: npm run build:worker

export const LIBLLAMA_VERSION = 'b9108-928b486';
export const LIBLLAMA_VERSION = 'b9138-2dfeca3';

export const LLAMA_CPP_WORKER_CODE = "// Start the main llama.cpp\nlet wllamaMalloc;\nlet wllamaStart;\nlet wllamaAction;\nlet wllamaExit;\nlet wllamaDebug;\n\nlet Module = null;\n\n//////////////////////////////////////////////////////////////\n// UTILS\n//////////////////////////////////////////////////////////////\n\n// send message back to main thread\nconst msg = (data, transfer) => postMessage(data, transfer);\n\n// Convert CPP log into JS log\nconst cppLogToJSLog = (line) => {\n const matched = line.match(/@@(DEBUG|INFO|WARN|ERROR)@@(.*)/);\n return !!matched\n ? {\n level: (matched[1] === 'INFO' ? 'debug' : matched[1]).toLowerCase(),\n text: matched[2],\n }\n : { level: 'log', text: line };\n};\n\nconst getHeapU8 = () => {\n const buffer = Module.wasmMemory.buffer;\n return new Uint8Array(buffer);\n};\n\n// Get module config that forwards stdout/err to main thread\nconst getWModuleConfig = (_argMainScriptBlob) => {\n var pathConfig = RUN_OPTIONS.pathConfig;\n var pthreadPoolSize = RUN_OPTIONS.nbThread;\n var argMainScriptBlob = _argMainScriptBlob;\n\n if (!pathConfig['wllama.wasm']) {\n throw new Error('\"wllama.wasm\" is missing in pathConfig');\n }\n return {\n noInitialRun: true,\n print: function (text) {\n if (arguments.length > 1)\n text = Array.prototype.slice.call(arguments).join(' ');\n msg({ verb: 'console.log', args: [text] });\n },\n printErr: function (text) {\n if (arguments.length > 1)\n text = Array.prototype.slice.call(arguments).join(' ');\n const logLine = cppLogToJSLog(text);\n msg({ verb: 'console.' + logLine.level, args: [logLine.text] });\n },\n locateFile: function (filename, basePath) {\n const p = pathConfig[filename];\n const truncate = (str) =>\n str.length > 128 ? `${str.substr(0, 128)}...` : str;\n if (filename.match(/wllama\\.worker\\.js/)) {\n msg({\n verb: 'console.error',\n args: [\n '\"wllama.worker.js\" is removed from v2.2.1. Hint: make sure to clear browser\\'s cache.',\n ],\n });\n } else {\n msg({\n verb: 'console.debug',\n args: [`Loading \"${filename}\" from \"${truncate(p)}\"`],\n });\n return p;\n }\n },\n mainScriptUrlOrBlob: argMainScriptBlob,\n pthreadPoolSize,\n wasmMemory: pthreadPoolSize > 1 ? getWasmMemory() : null,\n onAbort: function (text) {\n msg({ verb: 'signal.abort', args: [text] });\n },\n };\n};\n\n// Get the memory to be used by wasm. (Only used in multi-thread mode)\n// Because we have a weird OOM issue on iOS, we need to try some values\n// See: https://github.com/emscripten-core/emscripten/issues/19144\n// https://github.com/godotengine/godot/issues/70621\nconst getWasmMemory = () => {\n let minBytes = 128 * 1024 * 1024;\n let maxBytes = 4096 * 1024 * 1024;\n let stepBytes = 128 * 1024 * 1024;\n while (maxBytes > minBytes) {\n try {\n const wasmMemory = new WebAssembly.Memory({\n initial: BigInt(minBytes / 65536),\n maximum: BigInt(maxBytes / 65536),\n shared: true,\n address: 'i64',\n });\n return wasmMemory;\n } catch (e) {\n maxBytes -= stepBytes;\n continue; // retry\n }\n }\n throw new Error('Cannot allocate WebAssembly.Memory');\n};\n\n//////////////////////////////////////////////////////////////\n// HEAPFS PATCH\n//////////////////////////////////////////////////////////////\n\n/**\n * By default, emscripten uses memfs. The way it works is by\n * allocating new Uint8Array in javascript heap. This is not good\n * because it requires files to be copied to wasm heap each time\n * a file is read.\n *\n * HeapFS is an alternative, which resolves this problem by\n * allocating space for file directly inside wasm heap. This\n * allows us to mmap without doing any copy.\n *\n * For llama.cpp, this is great because we use MAP_SHARED\n *\n * Ref: https://github.com/ngxson/wllama/pull/39\n * Ref: https://github.com/emscripten-core/emscripten/blob/main/src/library_memfs.js\n *\n * Note 29/05/2024 @ngxson\n * Due to ftell() being limited to MAX_LONG, we cannot load files bigger than 2^31 bytes (or 2GB)\n * Ref: https://github.com/emscripten-core/emscripten/blob/main/system/lib/libc/musl/src/stdio/ftell.c\n */\n\nconst fsNameToFile = {}; // map Name => File\nconst fsIdToFile = {}; // map ID => File\nlet currFileId = 0;\n\n// Patch and redirect memfs calls to wllama\nconst patchHeapFS = () => {\n const m = Module;\n // save functions\n m.MEMFS.stream_ops._read = m.MEMFS.stream_ops.read;\n m.MEMFS.stream_ops._write = m.MEMFS.stream_ops.write;\n m.MEMFS.stream_ops._llseek = m.MEMFS.stream_ops.llseek;\n m.MEMFS.stream_ops._allocate = m.MEMFS.stream_ops.allocate;\n m.MEMFS.stream_ops._mmap = m.MEMFS.stream_ops.mmap;\n m.MEMFS.stream_ops._msync = m.MEMFS.stream_ops.msync;\n\n const patchStream = (stream) => {\n const name = stream.node.name;\n if (fsNameToFile[name]) {\n const f = fsNameToFile[name];\n stream.node.contents = getHeapU8().subarray(f.ptr, f.ptr + f.size);\n stream.node.usedBytes = f.size;\n }\n };\n\n // replace \"read\" functions\n m.MEMFS.stream_ops.read = function (\n stream,\n buffer,\n offset,\n length,\n position\n ) {\n patchStream(stream);\n return m.MEMFS.stream_ops._read(stream, buffer, offset, length, position);\n };\n m.MEMFS.ops_table.file.stream.read = m.MEMFS.stream_ops.read;\n\n // replace \"llseek\" functions\n m.MEMFS.stream_ops.llseek = function (stream, offset, whence) {\n patchStream(stream);\n return m.MEMFS.stream_ops._llseek(stream, offset, whence);\n };\n m.MEMFS.ops_table.file.stream.llseek = m.MEMFS.stream_ops.llseek;\n\n // replace \"mmap\" functions\n m.MEMFS.stream_ops.mmap = function (stream, length, position, prot, flags) {\n patchStream(stream);\n const name = stream.node.name;\n if (fsNameToFile[name]) {\n const f = fsNameToFile[name];\n return {\n ptr: f.ptr + position,\n allocated: false,\n };\n } else {\n return m.MEMFS.stream_ops._mmap(stream, length, position, prot, flags);\n }\n };\n m.MEMFS.ops_table.file.stream.mmap = m.MEMFS.stream_ops.mmap;\n\n // mount FS\n m.FS.mkdir('/models');\n m.FS.mount(m.MEMFS, { root: '.' }, '/models');\n};\n\n// Allocate a new file in wllama heapfs, returns file ID\nconst heapfsAlloc = (name, size) => {\n if (size < 1) {\n throw new Error('File size must be bigger than 0');\n }\n const m = Module;\n const ptr = m.mmapAlloc(size);\n const file = {\n ptr: ptr,\n size: size,\n id: currFileId++,\n };\n fsIdToFile[file.id] = file;\n fsNameToFile[name] = file;\n return file.id;\n};\n\n// Add new file to wllama heapfs, return number of written bytes\nconst heapfsWrite = (id, buffer, offset) => {\n const m = Module;\n if (fsIdToFile[id]) {\n const { ptr, size } = fsIdToFile[id];\n const afterWriteByte = offset + buffer.byteLength;\n if (afterWriteByte > size) {\n throw new Error(\n `File ID ${id} write out of bound, afterWriteByte = ${afterWriteByte} while size = ${size}`\n );\n }\n getHeapU8().set(buffer, ptr + offset);\n return buffer.byteLength;\n } else {\n throw new Error(`File ID ${id} not found in heapfs`);\n }\n};\n\n//////////////////////////////////////////////////////////////\n// MAIN CODE\n//////////////////////////////////////////////////////////////\n\nconst callWrapper = (name, ret, args, isAsync) => {\n const fn = Module.cwrap(\n name,\n ret,\n args,\n isAsync ? { async: true } : undefined\n );\n return async (action, req) => {\n // console.log(`Calling ${name} with action:`, action, 'and req:', req);\n let result;\n try {\n if (args.length === 2) {\n result = isAsync ? await fn(action, req) : fn(action, req);\n } else {\n result = fn();\n }\n } catch (ex) {\n console.error(ex);\n throw ex;\n }\n return result;\n };\n};\n\nonmessage = async (e) => {\n if (!e.data) return;\n const { verb, args, callbackId } = e.data;\n\n if (!callbackId) {\n msg({ verb: 'console.error', args: ['callbackId is required', e.data] });\n return;\n }\n\n if (verb === 'module.init') {\n const argMainScriptBlob = args[0];\n try {\n Module = getWModuleConfig(argMainScriptBlob);\n Module.preRun = () => {\n // ENV can be set here (for future use)\n };\n Module.onRuntimeInitialized = () => {\n // async call once module is ready\n // init FS\n patchHeapFS();\n // init cwrap\n const pointer = 'bigint';\n // TODO: note sure why emscripten cannot bind if there is only 1 argument\n wllamaMalloc = callWrapper('wllama_malloc', pointer, [\n 'number',\n pointer,\n ]);\n wllamaStart = callWrapper('wllama_start', 'string', [], true);\n wllamaAction = callWrapper(\n 'wllama_action',\n pointer,\n ['string', pointer],\n true\n );\n wllamaExit = callWrapper('wllama_exit', 'string', []);\n wllamaDebug = callWrapper('wllama_debug', 'string', []);\n msg({ callbackId, result: null });\n };\n wModuleInit();\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'fs.alloc') {\n const argFilename = args[0];\n const argSize = args[1];\n try {\n // create blank file\n const emptyBuffer = new ArrayBuffer(0);\n Module['FS_createDataFile'](\n '/models',\n argFilename,\n emptyBuffer,\n true,\n true,\n true\n );\n // alloc data on heap\n const fileId = heapfsAlloc(argFilename, argSize);\n msg({ callbackId, result: { fileId } });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'fs.write') {\n const argFileId = args[0];\n const argBuffer = args[1];\n const argOffset = args[2];\n try {\n const writtenBytes = heapfsWrite(argFileId, argBuffer, argOffset);\n msg({ callbackId, result: { writtenBytes } });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.start') {\n try {\n const result = await wllamaStart();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.action') {\n const argAction = args[0];\n const argEncodedMsg = args[1];\n try {\n const inputPtr = await wllamaMalloc(BigInt(argEncodedMsg.byteLength), 0);\n // copy data to wasm heap\n const inputBuffer = new Uint8Array(\n getHeapU8().buffer,\n Number(inputPtr),\n argEncodedMsg.byteLength\n );\n inputBuffer.set(argEncodedMsg, 0);\n const outputPtr = await wllamaAction(argAction, inputPtr);\n // length of output buffer is written at the first 4 bytes of input buffer\n const outputLen = new Uint32Array(\n getHeapU8().buffer,\n Number(inputPtr),\n 1\n )[0];\n // copy the output buffer to JS heap\n const outputBuffer = new Uint8Array(outputLen);\n const outputSrcView = new Uint8Array(\n getHeapU8().buffer,\n Number(outputPtr),\n outputLen\n );\n outputBuffer.set(outputSrcView, 0); // copy it\n msg({ callbackId, result: outputBuffer }, [outputBuffer.buffer]);\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.exit') {\n try {\n const result = await wllamaExit();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n\n if (verb === 'wllama.debug') {\n try {\n const result = await wllamaDebug();\n msg({ callbackId, result });\n } catch (err) {\n msg({ callbackId, err });\n }\n return;\n }\n};\n";

Expand Down
Loading