From a619c20b5b70145d9e52ddc6d4b9a51667f11602 Mon Sep 17 00:00:00 2001 From: Vellum Assistant Date: Tue, 5 May 2026 01:06:54 +0000 Subject: [PATCH] feat(memory-v2): cross-encoder rerank as additive boost MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an opt-in (`memory.v2.rerank.enabled: false` by default) cross-encoder rerank step that runs locally via the existing embedding-runtime worker infrastructure. When enabled, simBatch wraps the dense+sparse fused score with `boosted = clamp01(fused + alpha · normalized_rerank)` for the top-K candidates of the user and assistant similarity channels — NOW keeps pure fused since structured context is outside the cross-encoder's training distribution. Default model `Xenova/bge-reranker-base` (278M, MIT, ONNX); long-term target is `BAAI/bge-reranker-v2-m3` once a public ONNX export ships. --- .../schemas/__tests__/memory-v2.test.ts | 6 + assistant/src/config/schemas/memory-v2.ts | 48 +++ .../src/memory/embedding-runtime-manager.ts | 105 +++++- assistant/src/memory/rerank-local.ts | 351 ++++++++++++++++++ .../memory/v2/__tests__/activation.test.ts | 83 +++++ .../src/memory/v2/__tests__/reranker.test.ts | 218 +++++++++++ assistant/src/memory/v2/__tests__/sim.test.ts | 176 +++++++++ assistant/src/memory/v2/activation.ts | 6 +- assistant/src/memory/v2/reranker.ts | 126 +++++++ assistant/src/memory/v2/sim.ts | 39 ++ 10 files changed, 1153 insertions(+), 5 deletions(-) create mode 100644 assistant/src/memory/rerank-local.ts create mode 100644 assistant/src/memory/v2/__tests__/reranker.test.ts create mode 100644 assistant/src/memory/v2/reranker.ts diff --git a/assistant/src/config/schemas/__tests__/memory-v2.test.ts b/assistant/src/config/schemas/__tests__/memory-v2.test.ts index cb3e5dc2d80..bdb351c0f7d 100644 --- a/assistant/src/config/schemas/__tests__/memory-v2.test.ts +++ b/assistant/src/config/schemas/__tests__/memory-v2.test.ts @@ -26,6 +26,12 @@ describe("MemoryV2ConfigSchema", 
() => { consolidation_interval_hours: 4, max_page_chars: 5000, consolidation_prompt_path: null, + rerank: { + enabled: false, + top_k: 50, + alpha: 0.3, + model: "Xenova/bge-reranker-base", + }, }); }); diff --git a/assistant/src/config/schemas/memory-v2.ts b/assistant/src/config/schemas/memory-v2.ts index 394b2bcaed5..6fbb6c73c16 100644 --- a/assistant/src/config/schemas/memory-v2.ts +++ b/assistant/src/config/schemas/memory-v2.ts @@ -7,6 +7,13 @@ import { z } from "zod"; */ const WEIGHT_SUM_TOLERANCE = 0.001; +/** + * Default cross-encoder model for memory v2 reranking. `BAAI/bge-reranker-v2-m3` + * is the long-term target but currently lacks a public ONNX export; the + * `Xenova/bge-reranker-base` (278M, MIT, ONNX-converted) is the working pick. + */ +const DEFAULT_RERANK_MODEL = "Xenova/bge-reranker-base"; + /** * Memory v2 (concept-page activation model) configuration. * @@ -192,6 +199,47 @@ export const MemoryV2ConfigSchema = z .describe( "Optional path to a file whose contents replace the bundled consolidation prompt. Absolute paths are used as-is, a leading `~/` is expanded to the home directory, otherwise the path is resolved under the workspace root. The loaded contents may include `{{CUTOFF}}`, which is substituted with the run's ISO-8601 cutoff timestamp. If the file is missing, unreadable, or empty, the bundled prompt is used and a warning is logged.", ), + rerank: z + .object({ + enabled: z + .boolean() + .default(false) + .describe( + "Whether to apply cross-encoder reranking as an additive boost to the user + assistant similarity channels. Disabled by default — opt in once measured.", + ), + top_k: z + .number() + .int() + .positive() + .max(200) + .default(50) + .describe( + "Number of top-fused candidates per `simBatch` call to send through the reranker. Tail candidates keep their pure fused score.", + ), + alpha: z + .number() + .min(0) + .max(1) + .default(0.3) + .describe( + "Boost weight: `boosted = clamp01(fused + alpha · normalized_rerank)`. 
Top reranker hit can lift its fused score by up to `alpha`; bottom of top_k stays roughly unchanged.", + ), + model: z + .string() + .default(DEFAULT_RERANK_MODEL) + .describe( + "HuggingFace model id for the cross-encoder. Must have an ONNX export reachable from huggingface.co/{model}/resolve/main/onnx/model.onnx.", + ), + }) + .default({ + enabled: false, + top_k: 50, + alpha: 0.3, + model: DEFAULT_RERANK_MODEL, + }) + .describe( + "Cross-encoder rerank configuration. When enabled, runs a local cross-encoder over the top-K fused candidates per `simBatch(useRerank: true)` call and adds an alpha-weighted normalized boost to their fused scores.", + ), }) .describe( "Memory v2 — concept-page activation model with hourly LLM-driven consolidation", diff --git a/assistant/src/memory/embedding-runtime-manager.ts b/assistant/src/memory/embedding-runtime-manager.ts index 163977c3a51..a836585a73b 100644 --- a/assistant/src/memory/embedding-runtime-manager.ts +++ b/assistant/src/memory/embedding-runtime-manager.ts @@ -42,6 +42,7 @@ const JINJA_VERSION = "0.5.5"; const RUNTIME_VERSION = `ort-${ONNXRUNTIME_NODE_VERSION}_hf-${TRANSFORMERS_VERSION}_jinja-${JINJA_VERSION}`; const WORKER_FILENAME = "embed-worker.mjs"; +const RERANK_WORKER_FILENAME = "rerank-worker.mjs"; /** Module-level guard so concurrent in-process calls share one download. */ const installGuard = new PromiseGuard(); @@ -171,6 +172,91 @@ process.stdin.on('end', () => process.exit(0)); `; } +function generateRerankWorkerScript(): string { + // Cross-encoder rerank worker. Loads a sequence-classification model and + // scores (query, passage) pairs in batched ONNX inference calls. Mirrors + // the embed worker's lifecycle (ready signal, JSON-lines IPC, sequential + // queue) so LocalRerankBackend can reuse the same supervisor pattern. + return `\ +// rerank-worker.mjs — Auto-generated by EmbeddingRuntimeManager +// Runs in a separate bun process, communicates via JSON-lines over stdin/stdout. 
+process.title = 'rerank-worker'; +import { + AutoModelForSequenceClassification, + AutoTokenizer, + env, +} from '@huggingface/transformers'; + +const model = process.argv[2]; +const cacheDir = process.argv[3]; +if (cacheDir && env) env.cacheDir = cacheDir; + +let tokenizer; +let session; +try { + tokenizer = await AutoTokenizer.from_pretrained(model); + session = await AutoModelForSequenceClassification.from_pretrained(model, { dtype: 'fp32' }); + process.stdout.write(JSON.stringify({ type: 'ready' }) + '\\n'); +} catch (err) { + process.stdout.write(JSON.stringify({ type: 'error', error: err.message || String(err) }) + '\\n'); + process.exit(1); +} + +const sigmoid = (x) => 1 / (1 + Math.exp(-x)); + +const decoder = new TextDecoder(); +let buffer = ''; +let processing = false; +const queue = []; + +process.stdin.on('data', (chunk) => { + buffer += typeof chunk === 'string' ? chunk : decoder.decode(chunk, { stream: true }); + let idx; + while ((idx = buffer.indexOf('\\n')) !== -1) { + const line = buffer.slice(0, idx); + buffer = buffer.slice(idx + 1); + if (line.trim()) queue.push(line); + } + if (!processing) processQueue(); +}); + +async function processQueue() { + processing = true; + while (queue.length > 0) { + const line = queue.shift(); + let req; + try { req = JSON.parse(line); } catch { continue; } + try { + const { id, query, passages } = req; + if (!Array.isArray(passages) || passages.length === 0) { + process.stdout.write(JSON.stringify({ id, scores: [] }) + '\\n'); + continue; + } + const queries = passages.map(() => query); + const inputs = await tokenizer(queries, { + text_pair: passages, + padding: true, + truncation: true, + return_tensors: 'pt', + }); + const out = await session(inputs); + const logits = out.logits.data; + const scores = new Array(passages.length); + for (let i = 0; i < passages.length; i++) { + scores[i] = sigmoid(Number(logits[i])); + } + process.stdout.write(JSON.stringify({ id, scores }) + '\\n'); + } catch (err) { + 
process.stdout.write(JSON.stringify({ id: req?.id, error: err.message || String(err) }) + '\\n'); + } + } + processing = false; +} + +process.stdin.on('end', () => process.exit(0)); +`; +} + // ── Main manager ──────────────────────────────────────────────────── export class EmbeddingRuntimeManager { @@ -186,8 +272,12 @@ export class EmbeddingRuntimeManager { if (!manifest) return false; if (manifest.runtimeVersion !== RUNTIME_VERSION) return false; - // Verify the worker script exists and a bun binary is available - return existsSync(this.getWorkerPath()) && this.getBunPath() !== undefined; + // Verify both worker scripts exist and a bun binary is available + return ( + existsSync(this.getWorkerPath()) && + existsSync(this.getRerankWorkerPath()) && + this.getBunPath() !== undefined + ); } /** Path to the embed worker script. */ @@ -195,6 +285,11 @@ export class EmbeddingRuntimeManager { return join(this.baseDir, WORKER_FILENAME); } + /** Path to the rerank worker script. */ + getRerankWorkerPath(): string { + return join(this.baseDir, RERANK_WORKER_FILENAME); + } + /** * Find a usable bun binary. * Delegates to the shared bun-runtime helper, also checking @@ -375,8 +470,12 @@ export class EmbeddingRuntimeManager { ].join("\n"), ); - // Step 4: Write embed worker script + // Step 4: Write embed + rerank worker scripts writeFileSync(join(tmpDir, WORKER_FILENAME), generateWorkerScript()); + writeFileSync( + join(tmpDir, RERANK_WORKER_FILENAME), + generateRerankWorkerScript(), + ); // Step 5: Write version manifest const manifest: VersionManifest = { diff --git a/assistant/src/memory/rerank-local.ts b/assistant/src/memory/rerank-local.ts new file mode 100644 index 00000000000..71a9b95093c --- /dev/null +++ b/assistant/src/memory/rerank-local.ts @@ -0,0 +1,351 @@ +/** Local cross-encoder rerank backend — drives the rerank-worker subprocess. 
*/ +import { existsSync } from "node:fs"; + +import { getLogger } from "../util/logger.js"; +import { getEmbeddingModelsDir } from "../util/platform.js"; +import { PromiseGuard } from "../util/promise-guard.js"; +import { EmbeddingRuntimeManager } from "./embedding-runtime-manager.js"; + +const log = getLogger("memory-rerank-local"); + +interface WorkerResponse { + id?: number; + type?: string; + scores?: number[]; + error?: string; +} + +export class LocalRerankBackend { + readonly model: string; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + private workerProc: any = null; + private stdoutBuffer = ""; + private requestCounter = 0; + private pendingRequests = new Map< + number, + { resolve: (response: WorkerResponse) => void } + >(); + private stdoutReaderActive = false; + private activeRequests = 0; + private disposeRequested = false; + + private readyResolve: (() => void) | null = null; + private readyReject: ((err: Error) => void) | null = null; + + private readonly initGuard = new PromiseGuard(); + + constructor(model: string) { + this.model = model; + } + + /** Score `(query, passages[i])` pairs in one batched ONNX inference call. 
*/ + async score(query: string, passages: string[]): Promise { + if (this.disposeRequested) { + throw new Error("Local rerank backend is shutting down"); + } + if (passages.length === 0) return []; + + this.activeRequests++; + try { + await this.ensureInitialized(); + const response = await this.sendRequest({ query, passages }); + if (response.error) { + throw new Error(`Rerank worker error: ${response.error}`); + } + if (!response.scores) { + throw new Error("Rerank worker returned no scores"); + } + if (response.scores.length !== passages.length) { + throw new Error( + `Rerank worker returned ${response.scores.length} scores for ${passages.length} passages`, + ); + } + return response.scores; + } finally { + this.activeRequests--; + this.disposeIfIdle(); + } + } + + dispose(): void { + this.disposeRequested = true; + this.disposeIfIdle(); + } + + private sendRequest(payload: { + query: string; + passages: string[]; + }): Promise { + const id = ++this.requestCounter; + return new Promise((resolve) => { + if (!this.workerProc) { + resolve({ error: "Worker not initialized" }); + return; + } + this.pendingRequests.set(id, { resolve }); + this.workerProc.stdin.write(JSON.stringify({ id, ...payload }) + "\n"); + try { + this.workerProc.stdin.flush(); + } catch { + // Worker may have exited — stdout reader cleanup resolves pending requests. 
+ } + }); + } + + private async ensureInitialized(): Promise { + if (this.workerProc) return; + await this.initGuard.run(() => this.initialize()); + } + + private async initialize(): Promise { + log.info({ model: this.model }, "Initializing local rerank backend"); + + const runtimeManager = new EmbeddingRuntimeManager(); + if (!runtimeManager.isReady()) { + log.info("Embedding runtime not yet available, waiting for download..."); + await runtimeManager.ensureInstalled(); + } + + const bunPath = runtimeManager.getBunPath(); + const workerPath = runtimeManager.getRerankWorkerPath(); + + if (!bunPath) { + throw new Error("Local rerank backend unavailable: no bun binary found"); + } + if (!existsSync(workerPath)) { + throw new Error( + `Local rerank backend unavailable: worker script not found at ${workerPath}`, + ); + } + + await this.startWorker(bunPath, workerPath); + } + + private async startWorker( + bunPath: string, + workerPath: string, + ): Promise { + const embeddingModelsDir = getEmbeddingModelsDir(); + const modelCacheDir = `${embeddingModelsDir}/model-cache`; + + log.info( + { bunPath, workerPath, model: this.model }, + "Spawning rerank worker process", + ); + + const proc = Bun.spawn({ + cmd: [bunPath, "--smol", workerPath, this.model, modelCacheDir], + stdin: "pipe", + stdout: "pipe", + stderr: "pipe", + cwd: embeddingModelsDir, + }); + + this.workerProc = proc; + this.startStdoutReader(); + + try { + await this.waitForReady(); + } catch (err) { + this.workerProc = null; + this.stdoutReaderActive = false; + try { + proc.kill(); + } catch { + /* may already be dead */ + } + const exitCode = await proc.exited.catch(() => undefined); + const stderr = await new Response(proc.stderr).text().catch(() => ""); + if (stderr.trim()) { + log.warn({ stderr: stderr.trim(), exitCode }, "Rerank worker stderr"); + } + throw new Error( + `Rerank worker exited (code ${exitCode ?? "unknown"}): ${ + stderr.trim() || (err instanceof Error ? 
err.message : String(err)) + }`, + ); + } + + this.drainStderr(proc.stderr); + log.info( + { pid: proc.pid, model: this.model }, + "Rerank worker process started", + ); + this.disposeIfIdle(); + } + + private drainStderr(stderr: ReadableStream): void { + const reader = stderr.getReader(); + const decoder = new TextDecoder(); + (async () => { + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + const text = decoder.decode(value, { stream: true }).trim(); + if (text) log.debug({ workerStderr: text }, "Rerank worker stderr"); + } + } catch { + /* expected on shutdown */ + } + })(); + } + + private startStdoutReader(): void { + if (this.stdoutReaderActive || !this.workerProc) return; + this.stdoutReaderActive = true; + + const proc = this.workerProc; + const reader = proc.stdout.getReader(); + const decoder = new TextDecoder(); + + (async () => { + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + this.stdoutBuffer += decoder.decode(value, { stream: true }); + this.processStdoutBuffer(); + } + } catch { + /* reader cancelled or stream errored */ + } + + if (this.workerProc === proc) { + for (const pending of this.pendingRequests.values()) { + pending.resolve({ + error: "Rerank worker process exited unexpectedly", + }); + } + this.pendingRequests.clear(); + this.workerProc = null; + this.stdoutReaderActive = false; + this.stdoutBuffer = ""; + this.initGuard.reset(); + } + })(); + } + + private processStdoutBuffer(): void { + let idx: number; + while ((idx = this.stdoutBuffer.indexOf("\n")) !== -1) { + const line = this.stdoutBuffer.slice(0, idx); + this.stdoutBuffer = this.stdoutBuffer.slice(idx + 1); + if (!line.trim()) continue; + + let msg: WorkerResponse; + try { + msg = JSON.parse(line); + } catch { + continue; + } + + if (msg.type === "ready") { + this.readyResolve?.(); + this.readyResolve = null; + this.readyReject = null; + continue; + } + if (msg.type === "error" && 
this.readyReject) { + this.readyReject( + new Error(msg.error ?? "Worker initialization failed"), + ); + this.readyResolve = null; + this.readyReject = null; + continue; + } + + if (msg.id !== undefined) { + const pending = this.pendingRequests.get(msg.id); + if (pending) { + this.pendingRequests.delete(msg.id); + pending.resolve(msg); + this.disposeIfIdle(); + } + } + } + } + + private waitForReady(): Promise { + return new Promise((resolve, reject) => { + // First-call timeout. Generous because the first run downloads the + // ONNX weights (~280 MB to ~1 GB depending on model) before loading. + const timeout = setTimeout(() => { + this.readyResolve = null; + this.readyReject = null; + reject(new Error("Rerank worker timed out waiting for model to load")); + }, 120_000); + + this.readyResolve = () => { + clearTimeout(timeout); + resolve(); + }; + this.readyReject = (err: Error) => { + clearTimeout(timeout); + reject(err); + }; + + this.workerProc?.exited.then(() => { + if (this.readyResolve) { + clearTimeout(timeout); + this.readyResolve = null; + this.readyReject = null; + reject( + new Error("Rerank worker process exited before becoming ready"), + ); + } + }); + }); + } + + private disposeIfIdle(): void { + if (!this.disposeRequested) return; + if (this.activeRequests > 0) return; + if (this.pendingRequests.size > 0) return; + if (this.readyResolve || this.readyReject) return; + + const proc = this.workerProc; + this.workerProc = null; + this.stdoutReaderActive = false; + this.stdoutBuffer = ""; + this.initGuard.reset(); + + if (!proc) return; + + try { + proc.kill(); + } catch { + /* may already be exiting */ + } + } +} + +// ── Module-level singleton management ───────────────────────────────── + +let _backend: LocalRerankBackend | null = null; + +export function getOrCreateRerankBackend(model: string): LocalRerankBackend { + if (_backend?.model === model) return _backend; + if (_backend) { + try { + _backend.dispose(); + } catch { + /* best effort */ + } + } 
+ _backend = new LocalRerankBackend(model); + return _backend; +} + +/** @internal Test-only: reset the cached backend. */ +export function _resetRerankBackendForTests(): void { + if (_backend) { + try { + _backend.dispose(); + } catch { + /* best effort */ + } + } + _backend = null; +} diff --git a/assistant/src/memory/v2/__tests__/activation.test.ts b/assistant/src/memory/v2/__tests__/activation.test.ts index c504ea23b81..206790ef510 100644 --- a/assistant/src/memory/v2/__tests__/activation.test.ts +++ b/assistant/src/memory/v2/__tests__/activation.test.ts @@ -138,6 +138,30 @@ mock.module("@qdrant/js-client-rest", () => ({ QdrantClient: MockQdrantClient, })); +// Reranker mock — keeps the activation tests hermetic when rerank.enabled is +// flipped on by an integration case. Tests stage `rerankState.scores` to +// program the boost outcome. +const rerankState = { + scores: null as Map | null, + calls: [] as Array<{ query: string; candidates: string[] }>, +}; +mock.module("../reranker.js", () => ({ + rerankCandidates: async ( + query: string, + candidates: readonly string[], + ): Promise> => { + rerankState.calls.push({ query, candidates: [...candidates] }); + if (rerankState.scores === null) return new Map(); + const out = new Map(); + for (const slug of candidates) { + const v = rerankState.scores.get(slug); + if (v !== undefined) out.set(slug, v); + } + return out; + }, + _resetRerankCacheForTests: () => {}, +})); + // Static `import type` is fine — types erase, so they don't run module-init // code that would race the mocks above. 
import type { EdgeIndex } from "../edge-index.js"; @@ -169,6 +193,8 @@ function resetState(): void { state.skillQueryResponses.dense.length = 0; state.skillQueryResponses.sparse.length = 0; state.queryCalls.length = 0; + rerankState.scores = null; + rerankState.calls.length = 0; // Bun's `mock.module` persists across files in the same process, so the // qdrant modules' `_client` singletons may already hold a MockQdrantClient // instance from a sibling test file (e.g. sim.test.ts). Resetting both the @@ -554,6 +580,63 @@ describe("computeOwnActivation", () => { // No prior state → prev=0 → priorContribution=0 regardless of `d`. expect(out.breakdown.get("fresh")?.priorContribution).toBe(0); }); + + test("rerank boost on user/assistant flips top-1 when fused had it second", async () => { + // Three Qdrant queries fire in parallel inside computeOwnActivation: + // user, assistant, now. Stage identical hits for each so the only signal + // separating slugs is the rerank boost on the user + assistant channels. + const stagedHits = [ + { slug: "lexical", denseScore: 0.6, sparseScore: 0 }, + { slug: "semantic", denseScore: 0.5, sparseScore: 0 }, + ]; + stageHybridResponse(stagedHits); // user channel + stageHybridResponse(stagedHits); // assistant channel + stageHybridResponse(stagedHits); // now channel + rerankState.scores = new Map([ + ["lexical", 0.05], + ["semantic", 0.95], + ]); + + const config = { + memory: { + v2: { + d: 0.0, + c_user: 0.5, + c_assistant: 0.5, + c_now: 0.0, + dense_weight: 1.0, + sparse_weight: 0.0, + rerank: { + enabled: true, + top_k: 50, + alpha: 0.5, + model: "test-model", + }, + }, + }, + } as unknown as AssistantConfig; + + const out = await computeOwnActivation({ + candidates: new Set(["lexical", "semantic"]), + priorState: null, + userText: "u", + assistantText: "a", + nowText: "n", + config, + }); + + // Without rerank: lexical (0.6) would beat semantic (0.5) on both + // user and assistant channels. 
+ // With rerank (alpha=0.5): + // lexical: 0.6 + 0.5 · (0.05/0.95) ≈ 0.626 + // semantic: 0.5 + 0.5 · 1.0 = 1.0 + // The semantic candidate now wins on both rerank-boosted channels. + expect(out.activation.get("semantic")!).toBeGreaterThan( + out.activation.get("lexical")!, + ); + // Rerank should have been called once per rerank-enabled channel. + expect(rerankState.calls).toHaveLength(2); + }); }); // --------------------------------------------------------------------------- diff --git a/assistant/src/memory/v2/__tests__/reranker.test.ts b/assistant/src/memory/v2/__tests__/reranker.test.ts new file mode 100644 index 00000000000..c2902692360 --- /dev/null +++ b/assistant/src/memory/v2/__tests__/reranker.test.ts @@ -0,0 +1,218 @@ +/** + * Tests for `memory/v2/reranker.ts` — public `rerankCandidates` function. + * + * Mocks the underlying `LocalRerankBackend` and the `readPage` page reader so + * the test is hermetic (no subprocess, no filesystem). Verifies the public + * contract: scores keyed by slug, fail-open on backend failure, page-read + * failures drop slugs silently, LRU cache hits skip the backend. 
+ */ +import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; + +import { makeMockLogger } from "../../../__tests__/helpers/mock-logger.js"; +import type { AssistantConfig } from "../../../config/types.js"; + +mock.module("../../../util/logger.js", () => ({ + getLogger: () => makeMockLogger(), +})); + +mock.module("../../../util/platform.js", () => ({ + getWorkspaceDir: () => "/tmp/test-workspace", +})); + +const backendState = { + scores: [] as number[], + shouldThrow: false, + calls: [] as Array<{ query: string; passages: string[] }>, +}; +mock.module("../../rerank-local.js", () => ({ + getOrCreateRerankBackend: (_model: string) => ({ + score: async (query: string, passages: string[]): Promise => { + backendState.calls.push({ query, passages: [...passages] }); + if (backendState.shouldThrow) throw new Error("backend down"); + return backendState.scores.slice(0, passages.length); + }, + }), +})); + +const pageState = { + pages: new Map(), + failingSlugs: new Set(), +}; +// Partial mock — Bun's `mock.module` is process-wide, so we re-export every +// real symbol and override only `readPage`. Without this, sibling test files +// that import `listPages` etc. would crash with "Export not found". +const realPageStore = await import("../page-store.js"); +mock.module("../page-store.js", () => ({ + ...realPageStore, + readPage: async (_dir: string, slug: string) => { + if (pageState.failingSlugs.has(slug)) { + throw new Error("read failure"); + } + return pageState.pages.get(slug) ?? 
null; + }, +})); + +const { rerankCandidates, _resetRerankCacheForTests } = + await import("../reranker.js"); + +function configWithModel(model = "test-model"): AssistantConfig { + return { + memory: { + v2: { + rerank: { model, enabled: true, top_k: 50, alpha: 0.3 }, + }, + }, + } as unknown as AssistantConfig; +} + +function resetState() { + backendState.scores = []; + backendState.shouldThrow = false; + backendState.calls.length = 0; + pageState.pages.clear(); + pageState.failingSlugs.clear(); + _resetRerankCacheForTests(); +} + +beforeEach(resetState); +afterEach(resetState); + +describe("rerankCandidates", () => { + test("returns empty map for empty candidates", async () => { + const out = await rerankCandidates("query", [], configWithModel()); + expect(out.size).toBe(0); + expect(backendState.calls).toHaveLength(0); + }); + + test("returns empty map for whitespace-only query", async () => { + pageState.pages.set("a", { body: "content" }); + const out = await rerankCandidates(" ", ["a"], configWithModel()); + expect(out.size).toBe(0); + expect(backendState.calls).toHaveLength(0); + }); + + test("scores returned keyed by slug, in [0, 1]", async () => { + pageState.pages.set("a", { body: "first paragraph of a" }); + pageState.pages.set("b", { body: "first paragraph of b" }); + backendState.scores = [0.9, 0.1]; + + const out = await rerankCandidates("query", ["a", "b"], configWithModel()); + + expect(out.get("a")).toBe(0.9); + expect(out.get("b")).toBe(0.1); + }); + + test("clamps scores to [0, 1]", async () => { + pageState.pages.set("a", { body: "x" }); + pageState.pages.set("b", { body: "x" }); + backendState.scores = [1.5, -0.2]; + + const out = await rerankCandidates("query", ["a", "b"], configWithModel()); + + expect(out.get("a")).toBe(1); + expect(out.get("b")).toBe(0); + }); + + test("drops slugs whose page failed to read; others present", async () => { + pageState.pages.set("a", { body: "x" }); + pageState.failingSlugs.add("b"); + 
pageState.pages.set("c", { body: "y" }); + backendState.scores = [0.5, 0.7]; + + const out = await rerankCandidates( + "query", + ["a", "b", "c"], + configWithModel(), + ); + + expect(out.has("b")).toBe(false); + expect(out.get("a")).toBe(0.5); + expect(out.get("c")).toBe(0.7); + }); + + test("drops slugs whose page is null (missing on disk)", async () => { + pageState.pages.set("a", { body: "x" }); + pageState.pages.set("missing", null); + backendState.scores = [0.5]; + + const out = await rerankCandidates( + "query", + ["a", "missing"], + configWithModel(), + ); + + expect(out.size).toBe(1); + expect(out.get("a")).toBe(0.5); + expect(out.has("missing")).toBe(false); + }); + + test("returns empty map when backend throws (fail-open)", async () => { + pageState.pages.set("a", { body: "x" }); + backendState.shouldThrow = true; + + const out = await rerankCandidates("query", ["a"], configWithModel()); + + expect(out.size).toBe(0); + }); + + test("returns empty map when no pages load (no backend call)", async () => { + pageState.failingSlugs.add("a"); + + const out = await rerankCandidates("query", ["a"], configWithModel()); + + expect(out.size).toBe(0); + expect(backendState.calls).toHaveLength(0); + }); + + test("LRU cache hit skips the backend on identical inputs", async () => { + pageState.pages.set("a", { body: "x" }); + backendState.scores = [0.7]; + + const first = await rerankCandidates("query", ["a"], configWithModel()); + const second = await rerankCandidates("query", ["a"], configWithModel()); + + expect(first.get("a")).toBe(0.7); + expect(second.get("a")).toBe(0.7); + // Backend called only once — second call hit the cache. 
+ expect(backendState.calls).toHaveLength(1); + }); + + test("cache key insensitive to candidate order", async () => { + pageState.pages.set("a", { body: "x" }); + pageState.pages.set("b", { body: "y" }); + backendState.scores = [0.5, 0.6]; + + await rerankCandidates("query", ["a", "b"], configWithModel()); + await rerankCandidates("query", ["b", "a"], configWithModel()); + + // Same query, same set of candidates — second call hits cache. + expect(backendState.calls).toHaveLength(1); + }); + + test("passage construction caps at 240 chars after slug newline", async () => { + const longBody = "a".repeat(500); + pageState.pages.set("slug", { body: longBody }); + backendState.scores = [0.5]; + + await rerankCandidates("q", ["slug"], configWithModel()); + + expect(backendState.calls).toHaveLength(1); + const passage = backendState.calls[0].passages[0]; + // "slug\n" prefix + 240 chars of body + expect(passage.startsWith("slug\n")).toBe(true); + expect(passage.length).toBeLessThanOrEqual(5 + 240); + }); + + test("first paragraph is taken (body truncated at blank line)", async () => { + pageState.pages.set("slug", { + body: "first para line\n\nsecond para should not appear", + }); + backendState.scores = [0.5]; + + await rerankCandidates("q", ["slug"], configWithModel()); + + const passage = backendState.calls[0].passages[0]; + expect(passage).toContain("first para line"); + expect(passage).not.toContain("second para"); + }); +}); diff --git a/assistant/src/memory/v2/__tests__/sim.test.ts b/assistant/src/memory/v2/__tests__/sim.test.ts index 7250f7d2cd1..cbb5d303c0a 100644 --- a/assistant/src/memory/v2/__tests__/sim.test.ts +++ b/assistant/src/memory/v2/__tests__/sim.test.ts @@ -159,6 +159,31 @@ mock.module("@qdrant/js-client-rest", () => ({ QdrantClient: MockQdrantClient, })); +// Reranker mock — allows boost-mode tests to programmatically supply scores +// without spinning up the cross-encoder subprocess. 
// Reranker mock — lets boost-mode tests supply cross-encoder scores
// programmatically instead of spinning up the rerank subprocess.
const rerankState = {
  // slug → raw rerank score; `null` makes the mock return an empty Map.
  scores: null as Map<string, number> | null,
  // When true the mock rejects instead of returning scores.
  shouldThrow: false,
  // Every invocation is recorded so tests can assert on what was reranked.
  calls: [] as Array<{ query: string; candidates: string[] }>,
};
mock.module("../reranker.js", () => ({
  rerankCandidates: async (
    query: string,
    candidates: readonly string[],
  ): Promise<Map<string, number>> => {
    rerankState.calls.push({ query, candidates: [...candidates] });
    if (rerankState.shouldThrow) throw new Error("rerank disabled in test");
    if (rerankState.scores === null) return new Map();
    // Only return scores for slugs that were actually asked about.
    const out = new Map<string, number>();
    for (const slug of candidates) {
      const v = rerankState.scores.get(slug);
      if (v !== undefined) out.set(slug, v);
    }
    return out;
  },
  _resetRerankCacheForTests: () => {},
}));

// ---------------------------------------------------------------------------
// simBatch — cross-encoder rerank boost
// ---------------------------------------------------------------------------

describe("simBatch with rerank boost", () => {
  // dense_weight=1.0 / sparse_weight=0 so the fused score equals the dense
  // input directly — keeps the boost-math arithmetic readable in assertions.
  // The validator that requires the weights to sum to 1.0 only runs when the
  // schema is parsed; tests cast partial objects so it never fires.
  function configWithRerank(overrides: {
    enabled: boolean;
    top_k?: number;
    alpha?: number;
  }): AssistantConfig {
    return {
      memory: {
        v2: {
          dense_weight: 1.0,
          sparse_weight: 0.0,
          rerank: {
            enabled: overrides.enabled,
            top_k: overrides.top_k ?? 50,
            alpha: overrides.alpha ?? 0.3,
            model: "test-model",
          },
        },
      },
    } as unknown as AssistantConfig;
  }

  test("boosts top-K fused scores by alpha · normalized rerank", async () => {
    const config = configWithRerank({ enabled: true, top_k: 50, alpha: 0.4 });
    stageHybridResponse([
      { slug: "a", denseScore: 0.5 },
      { slug: "b", denseScore: 0.4 },
      { slug: "c", denseScore: 0.3 },
    ]);
    rerankState.scores = new Map([
      ["a", 0.2], // normalised → 0.2 / 0.8 = 0.25
      ["b", 0.8], // normalised → 1.0 (max)
      ["c", 0.4], // normalised → 0.5
    ]);

    const out = await simBatch("query", ["a", "b", "c"], config, {
      useRerank: true,
    });

    // a: clamp01(0.5 + 0.4·0.25) = 0.6
    // b: clamp01(0.4 + 0.4·1.0) = 0.8
    // c: clamp01(0.3 + 0.4·0.5) = 0.5
    expect(out.get("a")).toBeCloseTo(0.6);
    expect(out.get("b")).toBeCloseTo(0.8);
    expect(out.get("c")).toBeCloseTo(0.5);
  });

  test("rerank flips ranking when its top hit was dense's #2", async () => {
    const config = configWithRerank({ enabled: true, alpha: 0.5 });
    stageHybridResponse([
      { slug: "lexical-match", denseScore: 0.55 },
      { slug: "semantic-match", denseScore: 0.45 },
    ]);
    rerankState.scores = new Map([
      ["lexical-match", 0.05],
      ["semantic-match", 0.9],
    ]);

    const out = await simBatch(
      "query",
      ["lexical-match", "semantic-match"],
      config,
      { useRerank: true },
    );

    // lexical-match: 0.55 + 0.5 · (0.05/0.9) ≈ 0.578
    // semantic-match: 0.45 + 0.5 · 1.0 = 0.95
    expect(out.get("semantic-match")!).toBeGreaterThan(
      out.get("lexical-match")!,
    );
  });

  test("only top-K candidates get reranked; tail keeps pure fused", async () => {
    const config = configWithRerank({ enabled: true, top_k: 2, alpha: 0.5 });
    stageHybridResponse([
      { slug: "a", denseScore: 0.9 },
      { slug: "b", denseScore: 0.7 },
      { slug: "c", denseScore: 0.3 }, // tail — outside top_k=2
    ]);
    rerankState.scores = new Map([
      ["a", 0.5],
      ["b", 1.0],
      ["c", 1.0], // would lift but reranker is never called for it
    ]);

    const out = await simBatch("query", ["a", "b", "c"], config, {
      useRerank: true,
    });

    expect(rerankState.calls).toHaveLength(1);
    expect(rerankState.calls[0].candidates).toEqual(["a", "b"]);
    expect(out.get("c")).toBeCloseTo(0.3); // unchanged
  });

  test("returns pure fused when useRerank: true but rerank.enabled: false", async () => {
    const config = configWithRerank({ enabled: false });
    stageHybridResponse([{ slug: "a", denseScore: 0.5 }]);
    rerankState.scores = new Map([["a", 1.0]]);

    const out = await simBatch("query", ["a"], config, { useRerank: true });

    expect(rerankState.calls).toHaveLength(0);
    expect(out.get("a")).toBeCloseTo(0.5); // no boost applied
  });

  test("returns pure fused when reranker returns empty (fail-open)", async () => {
    const config = configWithRerank({ enabled: true });
    stageHybridResponse([{ slug: "a", denseScore: 0.5 }]);
    // The real `rerankCandidates` swallows worker errors and returns an
    // empty Map — `applyRerankBoost` short-circuits on empty.
    rerankState.scores = new Map();

    const out = await simBatch("query", ["a"], config, { useRerank: true });

    expect(out.get("a")).toBeCloseTo(0.5); // no boost
  });

  test("useRerank not passed — boost path doesn't run even when enabled", async () => {
    const config = configWithRerank({ enabled: true });
    stageHybridResponse([{ slug: "a", denseScore: 0.5 }]);
    rerankState.scores = new Map([["a", 1.0]]);

    const out = await simBatch("query", ["a"], config);

    expect(rerankState.calls).toHaveLength(0);
    expect(out.get("a")).toBeCloseTo(0.5);
  });

  test("clamps boosted score to <= 1", async () => {
    const config = configWithRerank({ enabled: true, alpha: 1.0 });
    stageHybridResponse([{ slug: "a", denseScore: 0.95 }]);
    rerankState.scores = new Map([["a", 0.8]]);

    const out = await simBatch("query", ["a"], config, { useRerank: true });

    // 0.95 + 1.0 · 1.0 = 1.95 → clamped to 1.0
    expect(out.get("a")).toBe(1);
  });
});
/** Memory v2 cross-encoder rerank — `(query, page-preview)` pairs scored by a local model. */

import { createHash } from "node:crypto";

import type { AssistantConfig } from "../../config/types.js";
import { getLogger } from "../../util/logger.js";
import { getWorkspaceDir } from "../../util/platform.js";
import { getOrCreateRerankBackend } from "../rerank-local.js";
import { readPage } from "./page-store.js";

const log = getLogger("memory-v2-reranker");

// ~512-token model context for bge-reranker-base; cap input to bound payload.
const PASSAGE_CHAR_CAP = 240;

/** One cached rerank result: raw scores per slug plus the time it was stored/refreshed. */
interface CacheEntry {
  scores: Map<string, number>;
  ts: number;
}

const CACHE_TTL_MS = 2 * 60 * 1000;
const CACHE_MAX_ENTRIES = 64;
// Insertion-ordered: the first key is the least recently inserted/refreshed.
const cache = new Map<string, CacheEntry>();

/**
 * Build the cache key for a rerank request.
 *
 * Slugs are sorted so the key is insensitive to candidate order, and the
 * model id is part of the key so scores produced by one model are never
 * served after the configured model changes.
 */
function cacheKey(
  model: string,
  query: string,
  slugs: readonly string[],
): string {
  const sorted = [...slugs].sort().join("\0");
  return createHash("sha256")
    .update(`${model}\0${query}\0${sorted}`)
    .digest("hex");
}

/**
 * Drop entries older than `CACHE_TTL_MS`, then — if the cache is still over
 * `CACHE_MAX_ENTRIES` — drop the oldest entries in insertion order.
 */
function evictExpired(now: number): void {
  for (const [k, v] of cache) {
    if (now - v.ts > CACHE_TTL_MS) cache.delete(k);
  }
  if (cache.size > CACHE_MAX_ENTRIES) {
    const toDrop = cache.size - CACHE_MAX_ENTRIES;
    let i = 0;
    for (const k of cache.keys()) {
      if (i++ >= toDrop) break;
      cache.delete(k);
    }
  }
}

/**
 * Build the passage text sent to the cross-encoder: the slug on the first
 * line, then the first paragraph of the body with whitespace collapsed and
 * truncated to `PASSAGE_CHAR_CAP` characters.
 *
 * A single leading markdown heading (and the blank lines after it) is removed
 * before the paragraph is cut, so a body shaped like `# Title\n\ntext`
 * previews `text` rather than just the heading. When the body consists of
 * nothing but that heading, the heading itself is kept as the preview.
 */
function buildPassage(slug: string, body: string): string {
  const trimmed = body.replace(/^\s+/, "");
  const afterHeading = trimmed.replace(/^#+\s.*\n\s*/, "");
  const content = afterHeading.length > 0 ? afterHeading : trimmed;
  const blank = content.search(/\n\s*\n/);
  const para = blank === -1 ? content : content.slice(0, blank);
  const compact = para
    .trim()
    .replace(/\s+/g, " ")
    .slice(0, PASSAGE_CHAR_CAP);
  return `${slug}\n${compact}`;
}

/**
 * Run the cross-encoder over each candidate's first-paragraph preview.
 * Returns raw sigmoid scores; failures (worker down, page read error) yield
 * an empty Map so callers can fall back to pure fused scores. Per-batch
 * normalisation and boost math live in `simBatch.applyRerankBoost`.
 *
 * @param query - Text the candidates are scored against; blank queries
 *   short-circuit to an empty Map.
 * @param candidates - Page slugs to score. Slugs whose page cannot be read
 *   are skipped and absent from the result.
 * @param config - Supplies `memory.v2.rerank.model`.
 * @returns slug → score clamped to [0, 1]. The returned Map is a fresh copy;
 *   mutating it does not affect the cache.
 */
export async function rerankCandidates(
  query: string,
  candidates: readonly string[],
  config: AssistantConfig,
): Promise<Map<string, number>> {
  if (candidates.length === 0 || query.trim().length === 0) {
    return new Map();
  }

  const model = config.memory.v2.rerank.model;
  const now = Date.now();
  evictExpired(now);
  const key = cacheKey(model, query, candidates);
  const cached = cache.get(key);
  if (cached) {
    // Refresh insertion order so frequently-hit entries survive eviction.
    cache.delete(key);
    cache.set(key, { ...cached, ts: now });
    return new Map(cached.scores);
  }

  const workspaceDir = getWorkspaceDir();
  const pages = await Promise.all(
    candidates.map((slug) =>
      readPage(workspaceDir, slug).catch((err) => {
        log.debug({ err, slug }, "Reranker skipping page that failed to load");
        return null;
      }),
    ),
  );

  // Keep passages and their slugs in lockstep so score[i] maps back to a slug.
  const passages: string[] = [];
  const slugsForPassages: string[] = [];
  for (let i = 0; i < candidates.length; i++) {
    const page = pages[i];
    if (!page) continue;
    passages.push(buildPassage(candidates[i], page.body));
    slugsForPassages.push(candidates[i]);
  }

  if (passages.length === 0) return new Map();

  let scores: number[];
  try {
    const backend = getOrCreateRerankBackend(model);
    scores = await backend.score(query, passages);
  } catch (err) {
    log.warn(
      { err, model, n: passages.length },
      "Rerank backend failed; falling back to pure fused scores",
    );
    return new Map();
  }

  const result = new Map<string, number>();
  for (let i = 0; i < slugsForPassages.length; i++) {
    const s = scores[i];
    // Missing or NaN scores are dropped so the slug keeps its fused score.
    if (typeof s !== "number" || Number.isNaN(s)) continue;
    // sigmoid output should already be in [0, 1]; clamp defensively.
    result.set(slugsForPassages[i], Math.max(0, Math.min(1, s)));
  }

  cache.set(key, { scores: new Map(result), ts: now });
  return result;
}

/** @internal Test-only: clear the LRU cache. */
export function _resetRerankCacheForTests(): void {
  cache.clear();
}
/**
 * Add a cross-encoder boost on top of the fused scores.
 *
 * The `top_k` highest fused slugs are sent to `rerankCandidates`; each raw
 * rerank score is normalised by the batch maximum and the slug's score
 * becomes `clampUnitInterval(fused + alpha · normalised)`. Slugs outside the
 * top-K, or for which the reranker returned no score, keep their fused score.
 *
 * Returns `fused` itself (not a copy) when there is nothing to boost: no
 * candidates, an empty rerank result, or every rerank score equal to 0.
 *
 * @param query - Text the candidates were scored against.
 * @param fused - slug → fused dense+sparse score; never mutated.
 * @param config - Supplies `memory.v2.rerank.top_k` and `alpha`.
 */
async function applyRerankBoost(
  query: string,
  fused: Map<string, number>,
  config: AssistantConfig,
): Promise<Map<string, number>> {
  const rerankCfg = config.memory.v2.rerank;
  const sortedSlugs = [...fused.entries()]
    .sort((a, b) => b[1] - a[1])
    .map(([slug]) => slug);
  const topSlugs = sortedSlugs.slice(0, rerankCfg.top_k);
  if (topSlugs.length === 0) return fused;

  const rerank = await rerankCandidates(query, topSlugs, config);
  if (rerank.size === 0) return fused;

  let maxRerank = 0;
  for (const v of rerank.values()) {
    if (v > maxRerank) maxRerank = v;
  }
  // All-zero rerank scores carry no signal and would divide by zero below.
  if (maxRerank === 0) return fused;

  const out = new Map(fused);
  for (const [slug, raw] of rerank) {
    const normalized = raw / maxRerank;
    const base = fused.get(slug) ?? 0;
    out.set(slug, clampUnitInterval(base + rerankCfg.alpha * normalized));
  }
  return out;
}