Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions assistant/src/config/schemas/__tests__/memory-v2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ describe("MemoryV2ConfigSchema", () => {
top_k: 50,
alpha: 0.3,
model: "Alibaba-NLP/gte-reranker-modernbert-base",
dtype: "q8",
},
});
});
Expand Down
22 changes: 22 additions & 0 deletions assistant/src/config/schemas/memory-v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ const WEIGHT_SUM_TOLERANCE = 0.001;
*/
const DEFAULT_RERANK_MODEL = "Alibaba-NLP/gte-reranker-modernbert-base";

/**
* ONNX weight precision passed to `@huggingface/transformers`. Sourced from
* transformers.js's supported `dtype` values; `q8` (int8) is ~3× faster than
* `fp32` on CPU with negligible reranker accuracy loss. Single source of
* truth for both the schema enum and the `LocalRerankBackend` type.
*/
export const RerankDtypeEnum = z.enum([
"fp32",
"fp16",
"q8",
"int8",
"uint8",
"q4",
"bnb4",
"q4f16",
]);
export type RerankDtype = z.infer<typeof RerankDtypeEnum>;

/**
* Memory v2 (concept-page activation model) configuration.
*
Expand Down Expand Up @@ -224,12 +242,16 @@ export const MemoryV2ConfigSchema = z
.describe(
"HuggingFace model id for the cross-encoder. Must have an ONNX export reachable from huggingface.co/<model>/resolve/main/onnx/model.onnx.",
),
dtype: RerankDtypeEnum.default("q8").describe(
"ONNX weight precision passed to `@huggingface/transformers`. `q8` (int8) is ~3× faster than `fp32` on CPU with negligible reranker accuracy loss. The worker fails to spawn if the configured model has no matching quantized export — `reranker.ts` then falls back to pure fused scores for the turn.",
),
})
.default({
enabled: false,
top_k: 50,
alpha: 0.3,
model: DEFAULT_RERANK_MODEL,
dtype: "q8",
})
.describe(
"Cross-encoder rerank configuration. When enabled, picks the top-K candidates by pre-rerank A_o, runs the cross-encoder once per channel (user, assistant) on that unified set, and adds an alpha-weighted normalized boost to A_o for each scored slug.",
Expand Down
33 changes: 24 additions & 9 deletions assistant/src/memory/embedding-runtime-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ const ONNXRUNTIME_COMMON_VERSION = "1.21.0";
const TRANSFORMERS_VERSION = "3.8.1";
const JINJA_VERSION = "0.5.5";

/** Composite version string for cache invalidation. */
const RUNTIME_VERSION = `ort-${ONNXRUNTIME_NODE_VERSION}_hf-${TRANSFORMERS_VERSION}_jinja-${JINJA_VERSION}`;
/**
* Composite version string for cache invalidation. Bumping the trailing
* `_workers-vN` suffix forces existing installs to regenerate the worker
* scripts when the worker IPC contract or spawn-args list changes (without
* requiring an `@huggingface/transformers` version bump).
*/
const RUNTIME_VERSION = `ort-${ONNXRUNTIME_NODE_VERSION}_hf-${TRANSFORMERS_VERSION}_jinja-${JINJA_VERSION}_workers-v2`;

const WORKER_FILENAME = "embed-worker.mjs";
const RERANK_WORKER_FILENAME = "rerank-worker.mjs";
Expand Down Expand Up @@ -174,9 +179,16 @@ process.stdin.on('end', () => process.exit(0));

function generateRerankWorkerScript(): string {
// Cross-encoder rerank worker. Loads a sequence-classification model and
// scores (query, passage) pairs in batched ONNX inference calls. Mirrors
// the embed worker's lifecycle (ready signal, JSON-lines IPC, sequential
// queue) so LocalRerankBackend can reuse the same supervisor pattern.
// scores paired (queries[i], passages[i]) tuples in one batched ONNX
// inference call. Mirrors the embed worker's lifecycle (ready signal,
// JSON-lines IPC, sequential queue) so LocalRerankBackend can reuse the
// same supervisor pattern.
//
// Request shape: { id, queries: string[], passages: string[] } with
// queries.length === passages.length. Each pair is one (query, passage)
// tuple; multiple distinct queries can ride in a single batch so the
// activation pipeline can score the user-channel and assistant-channel
// queries against a shared candidate set in one tokenizer + ONNX call.
return `\
// rerank-worker.mjs — Auto-generated by EmbeddingRuntimeManager
// Runs in a separate bun process, communicates via JSON-lines over stdin/stdout.
Expand All @@ -189,13 +201,14 @@ import {

const model = process.argv[2];
const cacheDir = process.argv[3];
const dtype = process.argv[4] || 'q8';
if (cacheDir && env) env.cacheDir = cacheDir;

let tokenizer;
let session;
try {
tokenizer = await AutoTokenizer.from_pretrained(model);
session = await AutoModelForSequenceClassification.from_pretrained(model, { dtype: 'fp32' });
session = await AutoModelForSequenceClassification.from_pretrained(model, { dtype });
process.stdout.write(JSON.stringify({ type: 'ready' }) + '\\n');
} catch (err) {
process.stdout.write(JSON.stringify({ type: 'error', error: err.message || String(err) }) + '\\n');
Expand Down Expand Up @@ -227,12 +240,14 @@ async function processQueue() {
let req;
try { req = JSON.parse(line); } catch { continue; }
try {
const { id, query, passages } = req;
if (!Array.isArray(passages) || passages.length === 0) {
const { id, queries, passages } = req;
if (
!Array.isArray(queries) || !Array.isArray(passages) ||
queries.length !== passages.length || passages.length === 0
) {
process.stdout.write(JSON.stringify({ id, scores: [] }) + '\\n');
continue;
}
const queries = passages.map(() => query);
const inputs = await tokenizer(queries, {
text_pair: passages,
padding: true,
Expand Down
43 changes: 33 additions & 10 deletions assistant/src/memory/rerank-local.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/** Local cross-encoder rerank backend — drives the rerank-worker subprocess. */
import { existsSync } from "node:fs";

import type { RerankDtype } from "../config/schemas/memory-v2.js";
import { getLogger } from "../util/logger.js";
import { getEmbeddingModelsDir } from "../util/platform.js";
import { PromiseGuard } from "../util/promise-guard.js";
Expand All @@ -17,6 +18,7 @@ interface WorkerResponse {

export class LocalRerankBackend {
readonly model: string;
readonly dtype: RerankDtype;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
private workerProc: any = null;
Expand All @@ -35,21 +37,32 @@ export class LocalRerankBackend {

private readonly initGuard = new PromiseGuard<void>();

constructor(model: string) {
constructor(model: string, dtype: RerankDtype) {
this.model = model;
this.dtype = dtype;
}

/** Score `(query, passages[i])` pairs in one batched ONNX inference call. */
async score(query: string, passages: string[]): Promise<number[]> {
/**
* Score paired `(queries[i], passages[i])` tuples in one batched ONNX
* inference call. Multiple distinct queries can ride in a single batch
* so callers can score the user-channel and assistant-channel queries
* against a shared candidate set in one tokenizer + forward pass.
*/
async score(queries: string[], passages: string[]): Promise<number[]> {
if (this.disposeRequested) {
throw new Error("Local rerank backend is shutting down");
}
if (passages.length === 0) return [];
if (queries.length !== passages.length) {
throw new Error(
`Rerank backend got ${queries.length} queries for ${passages.length} passages`,
);
}

this.activeRequests++;
try {
await this.ensureInitialized();
const response = await this.sendRequest({ query, passages });
const response = await this.sendRequest({ queries, passages });
if (response.error) {
throw new Error(`Rerank worker error: ${response.error}`);
}
Expand All @@ -74,7 +87,7 @@ export class LocalRerankBackend {
}

private sendRequest(payload: {
query: string;
queries: string[];
passages: string[];
}): Promise<WorkerResponse> {
const id = ++this.requestCounter;
Expand Down Expand Up @@ -130,12 +143,19 @@ export class LocalRerankBackend {
const modelCacheDir = `${embeddingModelsDir}/model-cache`;

log.info(
{ bunPath, workerPath, model: this.model },
{ bunPath, workerPath, model: this.model, dtype: this.dtype },
"Spawning rerank worker process",
);

const proc = Bun.spawn({
cmd: [bunPath, "--smol", workerPath, this.model, modelCacheDir],
cmd: [
bunPath,
"--smol",
workerPath,
this.model,
modelCacheDir,
this.dtype,
],
stdin: "pipe",
stdout: "pipe",
stderr: "pipe",
Expand Down Expand Up @@ -325,16 +345,19 @@ export class LocalRerankBackend {

let _backend: LocalRerankBackend | null = null;

export function getOrCreateRerankBackend(model: string): LocalRerankBackend {
if (_backend?.model === model) return _backend;
export function getOrCreateRerankBackend(
model: string,
dtype: RerankDtype,
): LocalRerankBackend {
if (_backend?.model === model && _backend.dtype === dtype) return _backend;
if (_backend) {
try {
_backend.dispose();
} catch {
/* best effort */
}
}
_backend = new LocalRerankBackend(model);
_backend = new LocalRerankBackend(model, dtype);
return _backend;
}

Expand Down
44 changes: 26 additions & 18 deletions assistant/src/memory/v2/__tests__/activation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,24 +125,31 @@ mock.module("@qdrant/js-client-rest", () => ({

// Reranker mock — keeps the activation tests hermetic when rerank.enabled is
// flipped on by an integration case. Tests stage `rerankState.scores` to
// program the boost outcome.
// program the boost outcome. The activation pipeline now passes both the
// user-channel and assistant-channel queries into a single rerank call, so
// `rerankState.calls` records the full `queries` array per invocation.
const rerankState = {
scores: null as Map<string, number> | null,
calls: [] as Array<{ query: string; candidates: string[] }>,
calls: [] as Array<{ queries: string[]; candidates: string[] }>,
};
mock.module("../reranker.js", () => ({
rerankCandidates: async (
query: string,
queries: readonly string[],
candidates: readonly string[],
): Promise<Map<string, number>> => {
rerankState.calls.push({ query, candidates: [...candidates] });
if (rerankState.scores === null) return new Map();
const out = new Map<string, number>();
for (const slug of candidates) {
const v = rerankState.scores.get(slug);
if (v !== undefined) out.set(slug, v);
}
return out;
): Promise<Array<Map<string, number>>> => {
rerankState.calls.push({
queries: [...queries],
candidates: [...candidates],
});
return queries.map(() => {
if (rerankState.scores === null) return new Map();
const out = new Map<string, number>();
for (const slug of candidates) {
const v = rerankState.scores.get(slug);
if (v !== undefined) out.set(slug, v);
}
return out;
});
},
_resetRerankCacheForTests: () => {},
}));
Expand Down Expand Up @@ -612,8 +619,9 @@ describe("computeOwnActivation", () => {
expect(out.activation.get("semantic")!).toBeGreaterThan(
out.activation.get("lexical")!,
);
// Rerank should have been called once per rerank-enabled channel.
expect(rerankState.calls).toHaveLength(2);
// Both rerank-enabled channels ride in a single batched rerank call.
expect(rerankState.calls).toHaveLength(1);
expect(rerankState.calls[0].queries).toEqual(["u", "a"]);
});

test("rerank pool is the unified top-K by pre-rerank A_o, not per-channel fused", async () => {
Expand Down Expand Up @@ -672,11 +680,11 @@ describe("computeOwnActivation", () => {
config,
});

expect(rerankState.calls).toHaveLength(2);
// Both channels rerank against the same unified slug set, sorted by
// pre-rerank A_o descending.
// Single batched rerank call carrying both channel queries against the
// unified slug set, sorted by pre-rerank A_o descending.
expect(rerankState.calls).toHaveLength(1);
expect(rerankState.calls[0].queries).toEqual(["u", "a"]);
expect(rerankState.calls[0].candidates).toEqual(["a", "c"]);
expect(rerankState.calls[1].candidates).toEqual(["a", "c"]);
});

test("rerank-disabled candidates outside the unified pool get zero boost", async () => {
Expand Down
Loading
Loading