Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions cpp/wllama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ struct wllama_context
}
}

std::pair<std::string, bool> get_next_result()
std::pair<server_task_result_ptr, bool> get_next_result()
{
server_task_result_ptr result = rd->next(should_stop);
if (result)
return {result->to_json().dump(), result->is_error()};
return {std::move(result), result->is_error()};
else
return {"", false};
return {nullptr, false};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

kv_dump dump_metadata()
Expand Down Expand Up @@ -492,11 +492,31 @@ struct wllama_context
glue_msg_get_result_res res;

bool has_more = run_loop();
auto [data_json, is_error] = get_next_result();
auto [result, is_error] = get_next_result();

json data_json;
if (result)
{
auto *res = dynamic_cast<server_task_result_embd *>(result.get());
if (res)
{
// special handling for embeddings OAI-compat
json body = {{"model", meta->model_name}};
json responses = json::array();
responses.push_back(result->to_json());
// TODO: support base64 output
data_json = format_embeddings_response_oaicompat(body, meta->model_name, responses, false);
}
else
{
// otherwise, it should be a completion result, nothing special to do
data_json = result->to_json();
}
}

res.success.value = true;
res.has_more.value = has_more;
res.data_json.value = data_json;
res.data_json.value = result ? data_json.dump() : "";
res.is_error.value = is_error;
return res;
}
Expand Down
4 changes: 2 additions & 2 deletions examples/basic/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ <h2>Embeddings</h2>
setEmbdDisable(true);
let embdA = await wllama1.createEmbedding({ input: elemInputA.value });
console.log({embdA});
embdA = embdA.embedding; // OAI-compat response
embdA = embdA.data[0].embedding; // OAI-compat response
let embdB = await wllama1.createEmbedding({ input: elemInputB.value });
console.log({embdB});
embdB = embdB.embedding; // OAI-compat response
embdB = embdB.data[0].embedding; // OAI-compat response
// since embeddings are normalized, we don't need to calculate norm
const dotProd = embdA.reduce((acc, _, i) => acc + embdA[i]*embdB[i], 0);
elemOutputEmbd.textContent = dotProd;
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
print(`Calculating embedding for sentence #${i}: "${sentence}"`);
timeStart();
const result = await wllama.createEmbedding({ input: sentence });
const vector = result.embedding;
const vector = result.data[0].embedding;
print(`OK, take ${timeEnd()} ms`);
embeddings.push(vector);
}
Expand Down
Binary file modified src/wasm/wllama.wasm
Binary file not shown.
4 changes: 2 additions & 2 deletions src/wllama.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ test.sequential('generates embeddings', async () => {
const res = await wllama.createEmbedding({ input: text });

expect(res).toBeDefined();
const embedding = (res as any).embedding as number[];
const embedding = res.data[0].embedding as number[];
expect(Array.isArray(embedding)).toBe(true);
expect(embedding.length).toBeGreaterThan(0);
for (const e of embedding) {
Expand All @@ -181,7 +181,7 @@ test.sequential('generates embeddings', async () => {

// slightly different text should have high cosine similarity
const res2 = await wllama.createEmbedding({ input: text + ' ' });
const embedding2 = (res2 as any).embedding as number[];
const embedding2 = res.data[0].embedding as number[];
const dot = embedding.reduce((acc, v, i) => acc + v * embedding2[i], 0);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
const norm1 = Math.sqrt(embedding.reduce((acc, v) => acc + v * v, 0));
const norm2 = Math.sqrt(embedding2.reduce((acc, v) => acc + v * v, 0));
Expand Down
3 changes: 2 additions & 1 deletion src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import CacheManager, { type DownloadOptions } from './cache-manager';
import { ModelManager, Model, type ModelSource } from './model-manager';
import type {
GlueMsgCompletionRes,
GlueMsgEmbeddingRes,
GlueMsgGetResultRes,
GlueMsgLoadRes,
} from './glue/messages';
Expand Down Expand Up @@ -555,7 +556,7 @@ export class Wllama {
): Promise<CreateEmbeddingResponse> {
this.checkModelLoaded();

const result = await this.proxy.wllamaAction<GlueMsgCompletionRes>(
const result = await this.proxy.wllamaAction<GlueMsgEmbeddingRes>(
'embedding',
{
_name: 'embd_req',
Expand Down
Loading