Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ jobs:
- name: Test (Firefox)
run: npm run test:firefox

# TODO: add test for WebGPU on github hosted runner ; current missing ShaderF16 support

lint:
runs-on: ubuntu-latest
steps:
Expand Down
8 changes: 6 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)

add_compile_options(
-O3 -msimd128 -DNDEBUG
-flto=full -frtti -fwasm-exceptions
-flto=full -frtti
-fwasm-exceptions
-pthread
-sMEMORY64=1
)
add_link_options(
-sMEMORY64=1
-flto=full -fwasm-exceptions
-flto=full
-fwasm-exceptions
--no-entry
-sEXPORT_ALL=1
-sEXPORT_ES6=0
Expand All @@ -31,6 +33,8 @@ add_link_options(
-sPTHREAD_POOL_SIZE=Module[\"pthreadPoolSize\"]
-sUSE_PTHREADS=1
-pthread
-sJSPI
-sJSPI_EXPORTS=['wllama_start','wllama_action']
)

add_subdirectory(llama.cpp)
Expand Down
44 changes: 31 additions & 13 deletions cpp/wllama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ struct wllama_context
params.cache_type_k = kv_cache_type_from_str(req.cache_type_k.value);
if (req.cache_type_v.not_null())
params.cache_type_v = kv_cache_type_from_str(req.cache_type_v.value);
if (req.flash_attn.not_null())
params.flash_attn_type = req.flash_attn.value ? LLAMA_FLASH_ATTN_TYPE_AUTO : LLAMA_FLASH_ATTN_TYPE_DISABLED;
if (req.swa_full.not_null())
params.swa_full = req.swa_full.value;
if (req.n_ctx_checkpoints.not_null())
Expand All @@ -314,22 +316,29 @@ struct wllama_context
params.chat_template = req.chat_template.value;
if (req.jinja.not_null())
params.use_jinja = req.jinja.value;
if (req.reasoning.not_null()) {
if (req.reasoning.value) {
if (req.reasoning.not_null())
{
if (req.reasoning.value)
{
params.enable_reasoning = 1;
params.default_template_kwargs["enable_thinking"] = "true";
} else {
}
else
{
params.enable_reasoning = 0;
params.default_template_kwargs["enable_thinking"] = "false";
}
}
if (req.default_template_kwargs_keys.not_null() && req.default_template_kwargs_vals.not_null()) {
if (req.default_template_kwargs_keys.not_null() && req.default_template_kwargs_vals.not_null())
{
auto &keys = req.default_template_kwargs_keys.arr;
auto &vals = req.default_template_kwargs_vals.arr;
if (keys.size() != vals.size()) {
if (keys.size() != vals.size())
{
throw app_exception("default_template_kwargs_keys and default_template_kwargs_vals must have the same length");
}
for (size_t i = 0; i < keys.size(); i++) {
for (size_t i = 0; i < keys.size(); i++)
{
params.default_template_kwargs[keys[i]] = vals[i];
}
}
Expand Down Expand Up @@ -422,28 +431,37 @@ struct wllama_context
json body = json::parse(req_raw);

json prompt;
if (body.count("input") != 0) {
if (body.count("input") != 0)
{
prompt = body.at("input");
} else if (body.contains("content")) {
}
else if (body.contains("content"))
{
prompt = body.at("content");
} else {
}
else
{
throw app_exception("\"input\" or \"content\" must be provided");
}

int embd_normalize = 2;
if (body.count("embd_normalize") != 0) {
if (body.count("embd_normalize") != 0)
{
embd_normalize = body.at("embd_normalize");
}

auto tokenized_prompts = tokenize_input_prompts(vocab, nullptr, prompt, true, true);
for (const auto &tokens : tokenized_prompts) {
if (tokens.empty()) {
for (const auto &tokens : tokenized_prompts)
{
if (tokens.empty())
{
throw app_exception("Input content cannot be empty");
}
}

std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
for (size_t i = 0; i < tokenized_prompts.size(); i++)
{
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
task.id = rd->get_new_id();
task.tokens = std::move(tokenized_prompts[i]);
Expand Down
25 changes: 25 additions & 0 deletions examples/multimodal/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
color: rgb(222, 222, 222);
font-family: 'Courier New', Courier, monospace;
padding: 1em;
padding-bottom: 4em;
}

#output_cmpl {
Expand Down Expand Up @@ -92,6 +93,13 @@ <h2>Multimodal (Vision) Completion</h2>

Output:<br />
<div id="output_cmpl"></div>
<div
id="output_timings"
style="margin-top: 0.5em; color: #aaa; font-size: 0.85em; display: none"
>
Prompt: <span id="timing_prompt">-</span> t/s &nbsp;|&nbsp; Generation:
<span id="timing_gen">-</span> t/s
</div>

<script type="module">
import { Wllama } from '../../esm/index.js';
Expand All @@ -112,6 +120,12 @@ <h2>Multimodal (Vision) Completion</h2>
async function main() {
setRunDisabled(true);

// Pre-load the example image
const response = await fetch('./bliss.png');
imageData = await response.arrayBuffer();
elemPreviewImage.src = './bliss.png';
elemPreviewImage.style.display = 'block';
Comment thread
ngxson marked this conversation as resolved.

elemBtnLoadRemote.onclick = async () => {
elemBtnLoadRemote.disabled = true;
elemBtnPickFiles.disabled = true;
Expand Down Expand Up @@ -149,6 +163,7 @@ <h2>Multimodal (Vision) Completion</h2>
if (!wllama) return;
setRunDisabled(true);
elemOutputCmpl.textContent = '';
elemOutputTimings.style.display = 'none';
try {
await runCompletion();
} catch (err) {
Expand Down Expand Up @@ -226,8 +241,15 @@ <h2>Multimodal (Vision) Completion</h2>
temperature: 0.2,
stream: true,
onData: (chunk) => {
console.log('Received chunk:', chunk);
const delta = chunk.choices[0]?.delta?.content;
if (delta) elemOutputCmpl.textContent += delta;
if (chunk.timings) {
const t = chunk.timings;
elemTimingPrompt.textContent = t.prompt_per_second.toFixed(1);
elemTimingGen.textContent = t.predicted_per_second.toFixed(1);
elemOutputTimings.style.display = 'block';
}
},
});
}
Expand Down Expand Up @@ -266,6 +288,9 @@ <h2>Multimodal (Vision) Completion</h2>
const elemPreviewImage = document.getElementById('preview_image');
const elemBtnRunCmpl = document.getElementById('btn_run_cmpl');
const elemOutputCmpl = document.getElementById('output_cmpl');
const elemOutputTimings = document.getElementById('output_timings');
const elemTimingPrompt = document.getElementById('timing_prompt');
const elemTimingGen = document.getElementById('timing_gen');

main();
</script>
Expand Down
2 changes: 1 addition & 1 deletion llama.cpp
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"format": "prettier --write .",
"test": "vitest",
"test:firefox": "BROWSER=firefox vitest",
"test:safari": "BROWSER=safari vitest"
"test:safari": "BROWSER=safari vitest",
"test:wgpu": "WEBGPU=1 vitest"
},
"repository": {
"type": "git",
Expand Down
12 changes: 11 additions & 1 deletion scripts/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@ services:

mkdir -p build
cd build
emcmake cmake ..
mkdir -p emdawn

DAWN_TAG=v20260317.182325
EMDAWN_PKG="emdawnwebgpu_pkg-$${DAWN_TAG}.zip"
EMDAWNWEBGPU_DIR="/source/build/emdawn/emdawnwebgpu_pkg"
echo "Downloading $${EMDAWN_PKG}"
curl -L -o emdawn.zip \
"https://github.com/google/dawn/releases/download/$${DAWN_TAG}/$${EMDAWN_PKG}"
python3 -c "import zipfile; zf=zipfile.ZipFile('emdawn.zip','r'); zf.extractall('/source/build/emdawn'); zf.close()"

emcmake cmake .. -DGGML_WEBGPU=ON -DGGML_WEBGPU_JSPI=ON -DEMDAWNWEBGPU_DIR="$${EMDAWNWEBGPU_DIR}"
emmake make wllama -j

# go back to root
Expand Down
15 changes: 15 additions & 0 deletions src/types/oai-compat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ export interface ChatCompletionChunkChoice {
logprobs: ChatCompletionChoiceLogprobs | null;
}

export interface ResultTimings {
cache_n: number;
prompt_n: number;
prompt_ms: number;
prompt_per_token_ms: number;
prompt_per_second: number;
predicted_n: number;
predicted_ms: number;
predicted_per_token_ms: number;
predicted_per_second: number;
}

/** Response when stream=true — one chunk per SSE event */
export interface ChatCompletionChunk {
id: string;
Expand All @@ -204,6 +216,7 @@ export interface ChatCompletionChunk {
model: string;
choices: ChatCompletionChunkChoice[];
usage?: ChatCompletionUsage | null;
timings?: ResultTimings;
}

// Raw (text) completion
Expand Down Expand Up @@ -249,6 +262,7 @@ export interface RawCompletionResponse {
choices: RawCompletionChoice[];
usage: ChatCompletionUsage;
system_fingerprint?: string;
timings?: ResultTimings;
}

/** One chunk when stream=true */
Expand All @@ -264,6 +278,7 @@ export interface RawCompletionChunk {
logprobs: null;
}>;
usage?: ChatCompletionUsage | null;
timings?: ResultTimings;
}

// Embeddings
Expand Down
3 changes: 2 additions & 1 deletion src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ export interface LoadModelParams {
seed?: number;
n_ctx?: number;
n_batch?: number;
// by default, all layers are offloaded if WebGPU is available
n_gpu_layers?: number;
// by default, on multi-thread build, we take half number of available threads (hardwareConcurrency / 2)
n_threads?: number;
embeddings?: boolean;
Expand All @@ -26,7 +28,6 @@ export interface LoadModelParams {
yarn_beta_fast?: number;
yarn_beta_slow?: number;
yarn_orig_ctx?: number;
// TODO: add group attention
// optimizations
cache_type_k?: 'f32' | 'f16' | 'q8_0' | 'q5_1' | 'q5_0' | 'q4_1' | 'q4_0';
cache_type_v?: 'f32' | 'f16' | 'q8_0' | 'q5_1' | 'q5_0' | 'q4_1' | 'q4_0';
Expand Down
21 changes: 21 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,20 @@ const isSupportSIMD = async () =>
])
);

/**
* @returns true if browser support JSPI
*/
export const isSupportJSPI = () => {
return !!(WebAssembly as any).Suspending;
};

/**
* @returns true if brower support WebGPU and JSPI (required by emscripten build)
*/
export const isSupportWebGPU = () => {
return !!(navigator as any).gpu && isSupportJSPI();
};

/**
* Throws an error if the environment is not compatible
*/
Expand All @@ -277,6 +291,13 @@ export const isSafari = (): boolean => {
); // safari
};

/**
* Check if browser is Firefox
*/
export const isFirefox = (): boolean => {
return !!navigator.userAgent.match(/Firefox\/([0-9\.]+)(?:\s|$)/);
};

/**
* Regular expression to validate GGUF file paths/URLs
* Matches paths ending with .gguf and optional query parameters
Expand Down
2 changes: 1 addition & 1 deletion src/wasm/wllama.js

Large diffs are not rendered by default.

Binary file modified src/wasm/wllama.wasm
Binary file not shown.
24 changes: 21 additions & 3 deletions src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ import {
absoluteUrl,
cbToAsyncIter,
checkEnvironmentCompatible,
isFirefox,
isString,
isSupportJSPI,
isSupportMultiThread,
isSupportWebGPU,
MMPROJ_FILE_NAME,
prepareBlobs,
} from './utils';
Expand Down Expand Up @@ -170,6 +173,13 @@ export class Wllama {
parallelDownloads: wllamaConfig.parallelDownloads,
allowOffline: wllamaConfig.allowOffline,
});

// warn user to enable JSPI on firefox
if (isFirefox() && !isSupportJSPI()) {
this.logger().warn(
'WebGPU is disabled on Firefox due to missing JSPI support. Please enable "javascript.options.wasm_js_promise_integration" in "about:config" to allow WebGPU support.'
);
}
}

private logger() {
Expand Down Expand Up @@ -343,6 +353,14 @@ export class Wllama {
return this.chatTemplate ?? null;
}

/**
* Check if WebGPU is supported by the current environment.
* @returns true if WebGPU is supported
*/
isSupportWebGPU(): boolean {
return isSupportWebGPU();
}

/**
* Load model from a given URL (or a list of URLs, in case the model is splitted into smaller files)
* - If the model already been downloaded (via `downloadModel()`), then we will use the cached model
Expand Down Expand Up @@ -411,7 +429,7 @@ export class Wllama {
if (this.proxy) {
throw new WllamaError('Module is already initialized', 'load_error');
}
// detect if we can use multi-thread
// detect if we can use multi-thread and webgpu
const supportMultiThread = await isSupportMultiThread();
const hwConccurency = Math.floor((navigator.hardwareConcurrency || 1) / 2);
const nbThreads = params.n_threads ?? hwConccurency;
Expand Down Expand Up @@ -452,8 +470,8 @@ export class Wllama {
log_level: logLevel,
use_mmap: true,
use_mlock: true,
n_gpu_layers: 0, // not supported for now
n_ctx: params.n_ctx || 1024,
n_gpu_layers: params.n_gpu_layers ?? 99999,
n_ctx: params.n_ctx ?? 1024,
n_threads: this.useMultiThread ? nbThreads : 1,
n_ctx_auto: false, // not supported for now
mmproj_path: modelFiles.mmproj
Expand Down
Loading
Loading