Skip to content
Merged
1 change: 1 addition & 0 deletions src/common/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export const LogId = {
toolUpdateFailure: mongoLogId(1_005_001),
resourceUpdateFailure: mongoLogId(1_005_002),
updateToolMetadata: mongoLogId(1_005_003),
toolValidationError: mongoLogId(1_005_004),

streamableHttpTransportStarted: mongoLogId(1_006_001),
streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002),
Expand Down
22 changes: 22 additions & 0 deletions src/helpers/collectFieldsFromVectorSearchFilter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Logical operators permitted in a $vectorSearch pre-filter. Based on -
// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];

/**
 * Collects the distinct field names referenced by a `$vectorSearch` filter.
 *
 * Keys that are not logical operators are treated as field names. Logical
 * operators whose operand is an array are descended into recursively, and the
 * fields found in every sub-filter are merged with the fields found at the
 * current level. A logical operator whose operand is not an array contributes
 * nothing.
 *
 * @param filter - The filter document; anything that is not a non-empty
 * object yields an empty result.
 * @returns The unique field names in first-seen order.
 */
export function collectFieldsFromVectorSearchFilter(filter: unknown): string[] {
    if (!filter || typeof filter !== "object") {
        return [];
    }

    // A Set both de-duplicates and keeps insertion order, so sibling keys and
    // nested sub-filters can all be accumulated into the same collection.
    const collectedFields = new Set<string>();
    for (const [maybeField, fieldMQL] of Object.entries(filter)) {
        if (!ALLOWED_LOGICAL_OPERATORS.includes(maybeField)) {
            collectedFields.add(maybeField);
            continue;
        }

        if (Array.isArray(fieldMQL)) {
            // Merge (rather than replace) so fields gathered from earlier keys
            // at this level are not lost when a logical operator is met.
            for (const mql of fieldMQL) {
                for (const nestedField of collectFieldsFromVectorSearchFilter(mql)) {
                    collectedFields.add(nestedField);
                }
            }
        }
    }

    return Array.from(collectedFields);
}
70 changes: 70 additions & 0 deletions src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "..
import { zEJSON } from "../../args.js";
import { LogId } from "../../../common/logger.js";
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
import { collectFieldsFromVectorSearchFilter } from "../../../helpers/collectFieldsFromVectorSearchFilter.js";

// Schema built by the zEJSON() helper imported from ../../args.js.
// NOTE(review): judging by the name, this stands for a pipeline stage of arbitrary shape — confirm against its usage.
const AnyStage = zEJSON();
const VectorSearchStage = z.object({
Expand Down Expand Up @@ -97,6 +98,7 @@ export class AggregateTool extends MongoDBToolBase {
try {
const provider = await this.ensureConnected();
await this.assertOnlyUsesPermittedStages(pipeline);
await this.assertVectorSearchFilterFieldsAreIndexed(database, collection, pipeline);

// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
Expand Down Expand Up @@ -202,6 +204,74 @@ export class AggregateTool extends MongoDBToolBase {
}
}

/**
 * Verifies that every field used in the `filter` of a `$vectorSearch` stage
 * is declared as a filter field on the search index that stage references.
 *
 * The check is best effort: it is skipped when the pipeline has no
 * `$vectorSearch` stage, when the session reports that search is not
 * supported, or when no filter-field information is available for the
 * referenced index (a warning is logged in that last case).
 *
 * @param database - Database holding the collection.
 * @param collection - Collection whose search indexes are inspected.
 * @param pipeline - Aggregation pipeline about to be executed.
 * @throws MongoDBError with `ErrorCodes.AtlasVectorSearchInvalidQuery` when a
 * stage filters on one or more fields the index does not list.
 */
private async assertVectorSearchFilterFieldsAreIndexed(
    database: string,
    collection: string,
    pipeline: Record<string, unknown>[]
): Promise<void> {
    // Nothing to validate without a $vectorSearch stage; bail out before
    // probing search support or listing indexes so ordinary aggregations do
    // not pay for (or fail on) those extra calls.
    const vectorSearchStages = pipeline.filter((stage) => "$vectorSearch" in stage);
    if (!vectorSearchStages.length) {
        return;
    }

    if (!(await this.session.isSearchSupported())) {
        return;
    }

    const searchIndexesWithFilterFields = await this.searchIndexesWithFilterFields(database, collection);
    for (const stage of vectorSearchStages) {
        const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
        const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
        if (!allowedFilterFields) {
            // Without index metadata there is nothing to compare against;
            // log and stop validating instead of rejecting the query.
            this.session.logger.warning({
                id: LogId.toolValidationError,
                context: "aggregate tool",
                message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
            });
            return;
        }

        const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
        const filterFieldsNotIndexed = filterFieldsInStage.filter((field) => !allowedFilterFields.includes(field));
        if (filterFieldsNotIndexed.length) {
            throw new MongoDBError(
                ErrorCodes.AtlasVectorSearchInvalidQuery,
                `Vector search stage contains filter on fields are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
            );
        }
    }
}

/**
 * Builds a lookup from search index name to the `path`s of that index's
 * `filter`-typed fields.
 *
 * Indexes whose latest definition carries no `fields` array (for example
 * search indexes that are not vector indexes) are left out of the result
 * rather than causing a crash, so callers see them as "no information".
 * An index that has `fields` but no filter entries maps to an empty array.
 *
 * @param database - Database holding the collection.
 * @param collection - Collection whose search indexes are listed.
 * @returns Map of index name to the filter field paths it declares.
 */
private async searchIndexesWithFilterFields(
    database: string,
    collection: string
): Promise<Record<string, string[]>> {
    const searchIndexes = (await this.session.serviceProvider.getSearchIndexes(database, collection)) as Array<{
        name: string;
        // Both levels are optional: the shape of the definition depends on
        // the kind of search index returned.
        latestDefinition?: {
            fields?: Array<
                | {
                      type: "vector";
                  }
                | {
                      type: "filter";
                      path: string;
                  }
            >;
        };
    }>;

    const indexFieldMap: Record<string, string[]> = {};
    for (const searchIndex of searchIndexes) {
        const fields = searchIndex.latestDefinition?.fields;
        if (!Array.isArray(fields)) {
            continue;
        }

        // flatMap keeps only filter entries and yields string[] directly,
        // with no reliance on inferred type predicates.
        indexFieldMap[searchIndex.name] = fields.flatMap((field) => (field.type === "filter" ? [field.path] : []));
    }

    return indexFieldMap;
}

private async countAggregationResultDocuments({
provider,
database,
Expand Down
Loading
Loading