Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions assistant/src/config/schemas/__tests__/memory-v2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ describe("MemoryV2ConfigSchema", () => {
consolidation_interval_hours: 4,
max_page_chars: 5000,
consolidation_prompt_path: null,
rerank: {
enabled: false,
top_k: 50,
alpha: 0.3,
model: "Xenova/bge-reranker-base",
},
});
});

Expand Down
48 changes: 48 additions & 0 deletions assistant/src/config/schemas/memory-v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import { z } from "zod";
*/
const WEIGHT_SUM_TOLERANCE = 0.001;

/**
 * Default cross-encoder model for memory v2 reranking.
 *
 * `BAAI/bge-reranker-v2-m3` is the long-term target but currently lacks a
 * public ONNX export, so `Xenova/bge-reranker-base` (278M, MIT,
 * ONNX-converted) is the working pick for now.
 */
const DEFAULT_RERANK_MODEL = "Xenova/bge-reranker-base";

/**
* Memory v2 (concept-page activation model) configuration.
*
Expand Down Expand Up @@ -192,6 +199,47 @@ export const MemoryV2ConfigSchema = z
.describe(
"Optional path to a file whose contents replace the bundled consolidation prompt. Absolute paths are used as-is, a leading `~/` is expanded to the home directory, otherwise the path is resolved under the workspace root. The loaded contents may include `{{CUTOFF}}`, which is substituted with the run's ISO-8601 cutoff timestamp. If the file is missing, unreadable, or empty, the bundled prompt is used and a warning is logged.",
),
rerank: z
.object({
enabled: z
.boolean()
.default(false)
.describe(
"Whether to apply cross-encoder reranking as an additive boost to the user + assistant similarity channels. Disabled by default — opt in once measured.",
),
top_k: z
.number()
.int()
.positive()
.max(200)
.default(50)
.describe(
"Number of top-fused candidates per `simBatch` call to send through the reranker. Tail candidates keep their pure fused score.",
),
alpha: z
.number()
.min(0)
.max(1)
.default(0.3)
.describe(
"Boost weight: `boosted = clamp01(fused + alpha · normalized_rerank)`. Top reranker hit can lift its fused score by up to `alpha`; bottom of top_k stays roughly unchanged.",
),
model: z
.string()
.default(DEFAULT_RERANK_MODEL)
.describe(
"HuggingFace model id for the cross-encoder. Must have an ONNX export reachable from huggingface.co/<model>/resolve/main/onnx/model.onnx.",
),
})
.default({
enabled: false,
top_k: 50,
alpha: 0.3,
model: DEFAULT_RERANK_MODEL,
})
.describe(
"Cross-encoder rerank configuration. When enabled, runs a local cross-encoder over the top-K fused candidates per `simBatch(useRerank: true)` call and adds an alpha-weighted normalized boost to their fused scores.",
),
})
.describe(
"Memory v2 — concept-page activation model with hourly LLM-driven consolidation",
Expand Down
105 changes: 102 additions & 3 deletions assistant/src/memory/embedding-runtime-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const JINJA_VERSION = "0.5.5";
const RUNTIME_VERSION = `ort-${ONNXRUNTIME_NODE_VERSION}_hf-${TRANSFORMERS_VERSION}_jinja-${JINJA_VERSION}`;

const WORKER_FILENAME = "embed-worker.mjs";
const RERANK_WORKER_FILENAME = "rerank-worker.mjs";

/** Module-level guard so concurrent in-process calls share one download. */
const installGuard = new PromiseGuard<void>();
Expand Down Expand Up @@ -171,6 +172,91 @@ process.stdin.on('end', () => process.exit(0));
`;
}

/**
 * Builds the source text of the cross-encoder rerank worker script.
 *
 * The generated module:
 * - reads the model id from `process.argv[2]` and an optional cache directory
 *   from `process.argv[3]`;
 * - loads a tokenizer and a sequence-classification model, then writes a
 *   `{ type: 'ready' }` JSON line to stdout (or `{ type: 'error', error }`
 *   followed by exit code 1 if loading fails);
 * - reads newline-delimited JSON requests `{ id, query, passages }` from
 *   stdin, handles them one at a time, and answers each with
 *   `{ id, scores }` or `{ id, error }` on stdout;
 * - exits with code 0 when stdin ends.
 *
 * @returns The complete worker script as a string, ready to be written to disk.
 */
function generateRerankWorkerScript(): string {
  // Cross-encoder rerank worker. Loads a sequence-classification model and
  // scores (query, passage) pairs in batched ONNX inference calls. Mirrors
  // the embed worker's lifecycle (ready signal, JSON-lines IPC, sequential
  // queue) so LocalRerankBackend can reuse the same supervisor pattern.
  //
  // NOTE(review): the tokenizer call passes the Python-style option
  // `return_tensors: 'pt'`; transformers.js appears to document
  // `return_tensor` (boolean) instead, so this option is presumably
  // ignored — confirm against the pinned transformers.js version.
  //
  // NOTE(review): scoring reads `logits[i]` for the i-th passage from the
  // flat logits buffer, which presumes a single-logit classification head —
  // confirm before pointing `model` at a multi-label checkpoint.
  return `\
// rerank-worker.mjs — Auto-generated by EmbeddingRuntimeManager
// Runs in a separate bun process, communicates via JSON-lines over stdin/stdout.
process.title = 'rerank-worker';
import {
  AutoModelForSequenceClassification,
  AutoTokenizer,
  env,
} from '@huggingface/transformers';

const model = process.argv[2];
const cacheDir = process.argv[3];
if (cacheDir && env) env.cacheDir = cacheDir;

// Null-safe error text: a thrown null/undefined must not crash the handler.
const describeError = (err) => err?.message || String(err);

let tokenizer;
let session;
try {
  tokenizer = await AutoTokenizer.from_pretrained(model);
  session = await AutoModelForSequenceClassification.from_pretrained(model, { dtype: 'fp32' });
  process.stdout.write(JSON.stringify({ type: 'ready' }) + '\\n');
} catch (err) {
  process.stdout.write(JSON.stringify({ type: 'error', error: describeError(err) }) + '\\n');
  process.exit(1);
}

const sigmoid = (x) => 1 / (1 + Math.exp(-x));

const decoder = new TextDecoder();
let buffer = '';
let processing = false;
const queue = [];

process.stdin.on('data', (chunk) => {
  buffer += typeof chunk === 'string' ? chunk : decoder.decode(chunk, { stream: true });
  let idx;
  while ((idx = buffer.indexOf('\\n')) !== -1) {
    const line = buffer.slice(0, idx);
    buffer = buffer.slice(idx + 1);
    if (line.trim()) queue.push(line);
  }
  if (!processing) processQueue();
});

async function processQueue() {
  processing = true;
  while (queue.length > 0) {
    const line = queue.shift();
    let req;
    // A line that is not valid JSON carries no request id to answer; skip it.
    try { req = JSON.parse(line); } catch { continue; }
    try {
      const { id, query, passages } = req;
      if (!Array.isArray(passages) || passages.length === 0) {
        process.stdout.write(JSON.stringify({ id, scores: [] }) + '\\n');
        continue;
      }
      // Pair the same query with every passage so the whole batch is scored in one call.
      const queries = passages.map(() => query);
      const inputs = await tokenizer(queries, {
        text_pair: passages,
        padding: true,
        truncation: true,
        return_tensors: 'pt',
      });
      const out = await session(inputs);
      const logits = out.logits.data;
      const scores = new Array(passages.length);
      for (let i = 0; i < passages.length; i++) {
        scores[i] = sigmoid(Number(logits[i]));
      }
      process.stdout.write(JSON.stringify({ id, scores }) + '\\n');
    } catch (err) {
      process.stdout.write(JSON.stringify({ id: req?.id, error: describeError(err) }) + '\\n');
    }
  }
  processing = false;
}

process.stdin.on('end', () => process.exit(0));
`;
}

// ── Main manager ────────────────────────────────────────────────────

export class EmbeddingRuntimeManager {
Expand All @@ -186,15 +272,24 @@ export class EmbeddingRuntimeManager {
if (!manifest) return false;
if (manifest.runtimeVersion !== RUNTIME_VERSION) return false;

// Verify the worker script exists and a bun binary is available
return existsSync(this.getWorkerPath()) && this.getBunPath() !== undefined;
// Verify both worker scripts exist and a bun binary is available
return (
existsSync(this.getWorkerPath()) &&
existsSync(this.getRerankWorkerPath()) &&
this.getBunPath() !== undefined
Comment thread
siddseethepalli marked this conversation as resolved.
);
Comment thread
siddseethepalli marked this conversation as resolved.
}

/**
 * Resolves the location of the embed worker script by joining
 * `WORKER_FILENAME` onto this manager's `baseDir`.
 *
 * @returns The joined path to the embed worker script.
 */
getWorkerPath(): string {
  const embedWorkerPath = join(this.baseDir, WORKER_FILENAME);
  return embedWorkerPath;
}

/**
 * Resolves the location of the rerank worker script by joining
 * `RERANK_WORKER_FILENAME` onto this manager's `baseDir`.
 *
 * @returns The joined path to the rerank worker script.
 */
getRerankWorkerPath(): string {
  const rerankWorkerPath = join(this.baseDir, RERANK_WORKER_FILENAME);
  return rerankWorkerPath;
}

/**
* Find a usable bun binary.
* Delegates to the shared bun-runtime helper, also checking
Expand Down Expand Up @@ -375,8 +470,12 @@ export class EmbeddingRuntimeManager {
].join("\n"),
);

// Step 4: Write embed worker script
// Step 4: Write embed + rerank worker scripts
writeFileSync(join(tmpDir, WORKER_FILENAME), generateWorkerScript());
writeFileSync(
join(tmpDir, RERANK_WORKER_FILENAME),
generateRerankWorkerScript(),
);

// Step 5: Write version manifest
const manifest: VersionManifest = {
Expand Down
Loading
Loading