Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8ac71ba
chore: Add new session-level service for getting embeddings of a spec…
kmruiz Oct 8, 2025
cb52116
chore: add unit tests to embedding validation
kmruiz Oct 8, 2025
082fce9
chore: add the ability to disable embedding validation
kmruiz Oct 9, 2025
ed7a16e
chore: Make sure that cache works
kmruiz Oct 9, 2025
d68deee
chore: Do not query for the embedding information if the validation i…
kmruiz Oct 9, 2025
32fe96d
chore: it can't be undefined anymore, so this check is useless
kmruiz Oct 9, 2025
2e013f8
chore: Embedding validation on insert and minor refactor of formatUnt…
kmruiz Oct 9, 2025
998cf1b
Merge remote-tracking branch 'origin/main' into chore/mcp-246
kmruiz Oct 9, 2025
81f9ddd
Update src/tools/mongodb/create/insertMany.ts
kmruiz Oct 9, 2025
0a1c789
chore: Add integration test for insert many
kmruiz Oct 13, 2025
c68e4ad
chore: Make eslint happy
kmruiz Oct 13, 2025
539c4a5
chore: test slightly older image of atlas-local in case it's broken i…
kmruiz Oct 13, 2025
44a3ce8
chore: increase timeout time for CI
kmruiz Oct 13, 2025
a5842ef
chore: minor fixes from the PR comments
kmruiz Oct 15, 2025
13c1c35
Merge remote-tracking branch 'origin/main' into chore/mcp-246
kmruiz Oct 15, 2025
a04c2f3
chore: Merge reliably search permission detection
kmruiz Oct 15, 2025
94fdcda
Merge branch 'main' into chore/mcp-246
kmruiz Oct 15, 2025
3264796
chore: cleanup embeddings cache when the connection is closed
kmruiz Oct 15, 2025
3b104b5
chore: clean up embeddings cache after creating an index
kmruiz Oct 15, 2025
19a333c
chore: simplify, assume search indexes are available just by listing …
kmruiz Oct 16, 2025
7eed735
chore: add the Manager suffix
kmruiz Oct 16, 2025
519a0c4
Merge branch 'main' into chore/mcp-246
kmruiz Oct 16, 2025
3d69362
Update src/common/search/vectorSearchEmbeddingsManager.ts
kmruiz Oct 16, 2025
debc6f9
chore: Remove unused error code and messages
kmruiz Oct 16, 2025
c0d9dee
chore: use ts private fields for now
kmruiz Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ const OPTIONS = {
boolean: [
"apiDeprecationErrors",
"apiStrict",
"disableEmbeddingsValidation",
"help",
"indexCheck",
"ipv6",
Expand Down Expand Up @@ -183,6 +184,7 @@ export interface UserConfig extends CliOptions {
maxBytesPerQuery: number;
atlasTemporaryDatabaseUserLifetimeMs: number;
voyageApiKey: string;
disableEmbeddingsValidation: boolean;
}

export const defaultUserConfig: UserConfig = {
Expand Down Expand Up @@ -213,6 +215,7 @@ export const defaultUserConfig: UserConfig = {
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
voyageApiKey: "",
disableEmbeddingsValidation: false,
};

export const config = setupUserConfig({
Expand Down
1 change: 1 addition & 0 deletions src/common/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export enum ErrorCodes {
MisconfiguredConnectionString = 1_000_001,
ForbiddenCollscan = 1_000_002,
ForbiddenWriteOperation = 1_000_003,
AtlasSearchNotAvailable = 1_000_004,
}

export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {
Expand Down
186 changes: 186 additions & 0 deletions src/common/search/vectorSearchEmbeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { BSON, type Document } from "bson";
import type { UserConfig } from "../config.js";

export type VectorFieldIndexDefinition = {
type: "vector";
path: string;
numDimensions: number;
quantization: "none" | "scalar" | "binary";
similarity: "euclidean" | "cosine" | "dotProduct";
};

export type EmbeddingNamespace = `${string}.${string}`;
export class VectorSearchEmbeddings {
constructor(
private readonly config: UserConfig,
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map(),
private readonly atlasSearchStatus: Map<string, boolean> = new Map()
) {}

cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void {
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
this.embeddings.delete(embeddingDefKey);
}

async embeddingsForNamespace({
database,
collection,
provider,
}: {
database: string;
collection: string;
provider: NodeDriverServiceProvider;
}): Promise<VectorFieldIndexDefinition[]> {
if (!(await this.isAtlasSearchAvailable(provider))) {
return [];
}

// We only need the embeddings for validation now, so don't query them if
// validation is disabled.
if (this.config.disableEmbeddingsValidation) {
return [];
}

const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
const definition = this.embeddings.get(embeddingDefKey);

if (!definition) {
const allSearchIndexes = await provider.getSearchIndexes(database, collection);
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch");
const vectorFields = vectorSearchIndexes
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? [])
.filter((field) => this.isVectorFieldIndexDefinition(field));

this.embeddings.set(embeddingDefKey, vectorFields);
return vectorFields;
} else {
return definition;
}
}

async findFieldsWithWrongEmbeddings(
{
database,
collection,
provider,
}: {
database: string;
collection: string;
provider: NodeDriverServiceProvider;
},
document: Document
): Promise<VectorFieldIndexDefinition[]> {
if (!(await this.isAtlasSearchAvailable(provider))) {
return [];
}

// While we can do our best effort to ensure that the embedding validation is correct
// based on https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-quantization/
// it's a complex process so we will also give the user the ability to disable this validation
if (this.config.disableEmbeddingsValidation) {
return [];
}

const embeddings = await this.embeddingsForNamespace({ database, collection, provider });
return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document));
}

async isAtlasSearchAvailable(provider: NodeDriverServiceProvider): Promise<boolean> {
const providerUri = provider.getURI();
if (!providerUri) {
// no URI? can't be cached
return await this.canListAtlasSearchIndexes(provider);
}

if (this.atlasSearchStatus.has(providerUri)) {
// has should ensure that get is always defined
return this.atlasSearchStatus.get(providerUri) ?? false;
}

const availability = await this.canListAtlasSearchIndexes(provider);
this.atlasSearchStatus.set(providerUri, availability);
return availability;
}

private isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition {
return doc["type"] === "vector";
}

private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
const fieldPath = definition.path.split(".");
let fieldRef: unknown = document;

for (const field of fieldPath) {
if (fieldRef && typeof fieldRef === "object" && field in fieldRef) {
fieldRef = (fieldRef as Record<string, unknown>)[field];
} else {
return true;
}
}

switch (definition.quantization) {
case "none":
return true;
case "scalar":
case "binary":
if (fieldRef instanceof BSON.Binary) {
try {
const elements = fieldRef.toFloat32Array();
return elements.length === definition.numDimensions;
} catch {
// bits are also supported
try {
const bits = fieldRef.toBits();
return bits.length === definition.numDimensions;
} catch {
return false;
}
}
} else {
if (!Array.isArray(fieldRef)) {
return false;
}

if (fieldRef.length !== definition.numDimensions) {
return false;
}

if (!fieldRef.every((e) => this.isANumber(e))) {
return false;
}
}

break;
}

return true;
}

private async canListAtlasSearchIndexes(provider: NodeDriverServiceProvider): Promise<boolean> {
try {
await provider.getSearchIndexes("test", "test");
return true;
} catch {
return false;
}
}

private isANumber(value: unknown): boolean {
if (typeof value === "number") {
return true;
}

if (
value instanceof BSON.Int32 ||
value instanceof BSON.Decimal128 ||
value instanceof BSON.Double ||
value instanceof BSON.Long
) {
return true;
}

return false;
}
}
5 changes: 5 additions & 0 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d
import { ErrorCodes, MongoDBError } from "./errors.js";
import type { ExportsManager } from "./exportsManager.js";
import type { Keychain } from "./keychain.js";
import type { VectorSearchEmbeddings } from "./search/vectorSearchEmbeddings.js";

export interface SessionOptions {
apiBaseUrl: string;
Expand All @@ -25,6 +26,7 @@ export interface SessionOptions {
exportsManager: ExportsManager;
connectionManager: ConnectionManager;
keychain: Keychain;
vectorSearchEmbeddings: VectorSearchEmbeddings;
}

export type SessionEvents = {
Expand All @@ -40,6 +42,7 @@ export class Session extends EventEmitter<SessionEvents> {
readonly connectionManager: ConnectionManager;
readonly apiClient: ApiClient;
readonly keychain: Keychain;
readonly vectorSearchEmbeddings: VectorSearchEmbeddings;

mcpClient?: {
name?: string;
Expand All @@ -57,6 +60,7 @@ export class Session extends EventEmitter<SessionEvents> {
connectionManager,
exportsManager,
keychain,
vectorSearchEmbeddings,
}: SessionOptions) {
super();

Expand All @@ -73,6 +77,7 @@ export class Session extends EventEmitter<SessionEvents> {
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
this.exportsManager = exportsManager;
this.connectionManager = connectionManager;
this.vectorSearchEmbeddings = vectorSearchEmbeddings;
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));
Expand Down
47 changes: 35 additions & 12 deletions src/tools/mongodb/create/insertMany.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
import { zEJSON } from "../../args.js";

export class InsertManyTool extends MongoDBToolBase {
Expand All @@ -23,19 +23,42 @@ export class InsertManyTool extends MongoDBToolBase {
documents,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const result = await provider.insertMany(database, collection, documents);

const embeddingValidations = new Set(
...(await Promise.all(
documents.flatMap((document) =>
this.session.vectorSearchEmbeddings.findFieldsWithWrongEmbeddings(
{ database, collection, provider },
document
)
)
))
);

if (embeddingValidations.size > 0) {
// tell the LLM what happened
const embeddingValidationMessages = [...embeddingValidations].map(
(validation) =>
`- Field ${validation.path} is an embedding with ${validation.numDimensions} dimensions and ${validation.quantization} quantization, and the provided value is not compatible.`
);

return {
content: formatUntrustedData(
"There were errors when inserting documents. No document was inserted.",
...embeddingValidationMessages
),
isError: true,
};
}

const result = await provider.insertMany(database, collection, documents);
const content = formatUntrustedData(
"Documents were inserted successfully.",
`Inserted \`${result.insertedCount}\` document(s) into ${database}.${collection}.`,
`Inserted IDs: ${Object.values(result.insertedIds).join(", ")}`
);
return {
content: [
{
text: `Inserted \`${result.insertedCount}\` document(s) into collection "${collection}"`,
type: "text",
},
{
text: `Inserted IDs: ${Object.values(result.insertedIds).join(", ")}`,
type: "text",
},
],
content,
};
}
}
6 changes: 1 addition & 5 deletions src/tools/mongodb/metadata/collectionIndexes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ export class CollectionIndexesTool extends MongoDBToolBase {
return {
content: formatUntrustedData(
`Found ${indexes.length} indexes in the collection "${collection}":`,
indexes.length > 0
? indexes
.map((index) => `Name: "${index.name}", definition: ${JSON.stringify(index.key)}`)
.join("\n")
: undefined
...indexes.map((index) => `Name: "${index.name}", definition: ${JSON.stringify(index.key)}`)
),
};
}
Expand Down
4 changes: 1 addition & 3 deletions src/tools/mongodb/metadata/listDatabases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ export class ListDatabasesTool extends MongoDBToolBase {
return {
content: formatUntrustedData(
`Found ${dbs.length} databases`,
dbs.length > 0
? dbs.map((db) => `Name: ${db.name}, Size: ${db.sizeOnDisk.toString()} bytes`).join("\n")
: undefined
...dbs.map((db) => `Name: ${db.name}, Size: ${db.sizeOnDisk.toString()} bytes`)
),
};
}
Expand Down
12 changes: 12 additions & 0 deletions src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ export abstract class MongoDBToolBase extends ToolBase {
return this.session.serviceProvider;
}

protected async ensureSearchAvailable(): Promise<NodeDriverServiceProvider> {
const provider = await this.ensureConnected();
if (!(await this.session.vectorSearchEmbeddings.isAtlasSearchAvailable(provider))) {
throw new MongoDBError(
ErrorCodes.AtlasSearchNotAvailable,
"This MongoDB cluster does not support Search Indexes. Make sure you are using an Atlas Cluster, either remotely in Atlas or using the Atlas Local image, or your cluster supports MongoDB Search."
);
}

return provider;
}

public register(server: Server): boolean {
this.server = server;
return super.register(server);
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export class AggregateTool extends MongoDBToolBase {
cursorResults.cappedBy,
].filter((limit): limit is keyof typeof CURSOR_LIMITS_TO_LLM_TEXT => !!limit),
}),
cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined
...(cursorResults.documents.length > 0 ? [EJSON.stringify(cursorResults.documents)] : [])
),
};
} finally {
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/read/find.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ export class FindTool extends MongoDBToolBase {
documents: cursorResults.documents,
appliedLimits: [limitOnFindCursor.cappedBy, cursorResults.cappedBy].filter((limit) => !!limit),
}),
cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined
...(cursorResults.documents.length > 0 ? [EJSON.stringify(cursorResults.documents)] : [])
),
};
} finally {
Expand Down
Loading
Loading