diff --git a/.vscode/cspell.json b/.vscode/cspell.json index db18c5f58dff..9192cb7b1615 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -320,6 +320,8 @@ "Parition", "colls", "pkranges", + "rerank", + "Rerank", "sproc", "sprocs", "udfs", diff --git a/sdk/cosmosdb/cosmos/review/cosmos-node.api.md b/sdk/cosmosdb/cosmos/review/cosmos-node.api.md index b84f0b83720b..092f7d564e78 100644 --- a/sdk/cosmosdb/cosmos/review/cosmos-node.api.md +++ b/sdk/cosmosdb/cosmos/review/cosmos-node.api.md @@ -317,6 +317,7 @@ export class ClientContext { diagnosticNode: DiagnosticNodeInternal; partitionKeyRangeId?: string; }): Promise>; + semanticRerank(context: string, documents: string[], options?: SemanticRerankOptions): Promise; // (undocumented) upsert(input: { body: T; @@ -700,6 +701,11 @@ export const Constants: { DefaultEncryptionCacheTimeToLiveInSeconds: number; EncryptionCacheRefreshIntervalInMs: number; RequestTimeoutForReadsInMs: number; + InferenceBasePath: string; + InferenceUserAgent: string; + InferenceDefaultScope: string; + InferenceDefaultTimeoutMs: number; + InferenceEndpointEnvVar: string; }; // @public @@ -731,6 +737,7 @@ export class Container { readPartitionKeyRanges(feedOptions?: FeedOptions): QueryIterator; replace(body: ContainerDefinition, options?: RequestOptions): Promise; get scripts(): Scripts; + semanticRerank(context: string, documents: string[], options?: SemanticRerankOptions): Promise; get url(): string; } @@ -824,6 +831,7 @@ export interface CosmosClientOptions { diagnosticLevel?: CosmosDbDiagnosticLevel; endpoint?: string; httpClient?: HttpClient; + inferenceEndpoint?: string; key?: string; permissionFeed?: PermissionDefinition[]; resourceTokens?: { @@ -2133,6 +2141,13 @@ export interface RequestOptions extends SharedOptions { urlConnection?: string; } +// @public +export interface RerankScore { + document: string | null; + index: number; + score: number; +} + // @public (undocumented) export interface Resource { _etag: string; @@ -2377,6 +2392,17 @@ export class Scripts { get userDefinedFunctions(): UserDefinedFunctions; } +// @public +export type SemanticRerankOptions = Record; + +// @public +export interface SemanticRerankResult { + headers: Record; + latency: Record | undefined; + rerankScores: RerankScore[]; + tokenUsage: Record | undefined; +} + // @public export function setAuthorizationTokenHeaderUsingMasterKey(verb: HTTPMethod, resourceId: string, resourceType: ResourceType, headers: CosmosHeaders, masterKey: string): Promise; @@ -2462,6 +2488,8 @@ export interface StatusCodesType { // (undocumented) MethodNotAllowed: 405; // (undocumented) + MultipleChoices: 300; + // (undocumented) MultiStatus: 207; // (undocumented) NoContent: 204; diff --git a/sdk/cosmosdb/cosmos/src/ClientContext.ts b/sdk/cosmosdb/cosmos/src/ClientContext.ts index dc875edb1ac5..782eb3fdd987 100644 --- a/sdk/cosmosdb/cosmos/src/ClientContext.ts +++ b/sdk/cosmosdb/cosmos/src/ClientContext.ts @@ -51,6 +51,9 @@ import { AAD_AUTH_PREFIX, AAD_RESOURCE_NOT_FOUND_ERROR, } from "./common/constants.js"; +import { InferenceService } from "./inference/InferenceService.js"; +import type { SemanticRerankOptions } from "./inference/SemanticRerankOptions.js"; +import type { SemanticRerankResult } from "./inference/SemanticRerankResult.js"; const logger: AzureLogger = createClientLogger("ClientContext"); @@ -70,6 +73,7 @@ export class ClientContext { public partitionKeyRangeCache: PartitionKeyRangeCache; /** boolean flag to support operations with client-side encryption */ public enableEncryption: boolean = false; + private inferenceService: InferenceService | null = null; public constructor( private cosmosClientOptions: CosmosClientOptions, @@ -1108,4 +1112,45 @@ export class ClientContext { this.globalEndpointManager.lastKnownPPCBEnabled ); } + + /** + * Rerank a list of documents using semantic reranking via the Cosmos DB Inference Service. + * This method uses a semantic reranker to score and reorder the provided documents + * based on their relevance to the given reranking context. + * + * The semantic reranking requests use a separate HTTP pipeline and do not use + * the default SDK retry policies. + * + * @param rerankContext - The context (e.g. query string) to use for reranking. + * @param documents - The documents to be reranked. + * @param options - Optional settings for the reranking request. + * @returns The reranking results including scores, latency, and token usage. + */ + public async semanticRerank( + context: string, + documents: string[], + options?: SemanticRerankOptions, + ): Promise { + const service = this.getOrCreateInferenceService(); + return service.semanticRerank(context, documents, options); + } + + /** + * Gets or lazily creates the InferenceService instance. + * @internal + */ + private getOrCreateInferenceService(): InferenceService { + if (!this.inferenceService) { + this.inferenceService = new InferenceService(this.cosmosClientOptions); + } + return this.inferenceService; + } + + /** + * Disposes the InferenceService if it was created. + * @internal + */ + public disposeInferenceService(): void { + this.inferenceService = null; + } } diff --git a/sdk/cosmosdb/cosmos/src/CosmosClient.ts b/sdk/cosmosdb/cosmos/src/CosmosClient.ts index 2b7c1be69c9c..87bbea0ae0d8 100644 --- a/sdk/cosmosdb/cosmos/src/CosmosClient.ts +++ b/sdk/cosmosdb/cosmos/src/CosmosClient.ts @@ -355,6 +355,7 @@ export class CosmosClient { if (this.globalPartitionEndpointManager) { this.globalPartitionEndpointManager.dispose(); } + this.clientContext.disposeInferenceService(); } private async backgroundRefreshEndpointList( diff --git a/sdk/cosmosdb/cosmos/src/CosmosClientOptions.ts b/sdk/cosmosdb/cosmos/src/CosmosClientOptions.ts index 54237731419e..e9823b79d307 100644 --- a/sdk/cosmosdb/cosmos/src/CosmosClientOptions.ts +++ b/sdk/cosmosdb/cosmos/src/CosmosClientOptions.ts @@ -81,4 +81,10 @@ export interface CosmosClientOptions { /** An optional parameter that represents the connection string. Your database connection string can be found in the Azure Portal. */ connectionString?: string; + + /** + * The endpoint URL for the Cosmos DB Inference Service, used for features such as semantic reranking. + * If not provided, the SDK falls back to the `AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT` environment variable. + */ + inferenceEndpoint?: string; } diff --git a/sdk/cosmosdb/cosmos/src/client/Container/Container.ts b/sdk/cosmosdb/cosmos/src/client/Container/Container.ts index 18b4cc08a6ec..e0c034504a2f 100644 --- a/sdk/cosmosdb/cosmos/src/client/Container/Container.ts +++ b/sdk/cosmosdb/cosmos/src/client/Container/Container.ts @@ -44,6 +44,8 @@ import { MetadataLookUpType } from "../../CosmosDiagnostics.js"; import type { EncryptionSettingForProperty } from "../../encryption/index.js"; import { EncryptionProcessor } from "../../encryption/index.js"; import type { EncryptionManager } from "../../encryption/EncryptionManager.js"; +import type { SemanticRerankOptions } from "../../inference/SemanticRerankOptions.js"; +import type { SemanticRerankResult } from "../../inference/SemanticRerankResult.js"; /** * Operations for reading, replacing, or deleting a specific, existing container by id. @@ -691,6 +693,74 @@ export class Container { } } + /** + * Rerank a list of documents using semantic reranking via the Cosmos DB Inference Service. + * This method uses a semantic reranker to score and reorder the provided documents + * based on their relevance to the given context. + * + * The semantic reranking requests use a separate HTTP pipeline from the main Cosmos DB client + * and do not use the default SDK retry policies. + * + * To use this feature, you must: + * 1. Configure AAD authentication via `aadCredentials` in `CosmosClientOptions` + * 2. Provide the inference endpoint via `inferenceEndpoint` in `CosmosClientOptions`, + * or set the `AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT` environment variable + * + * @param context - The context (e.g. query string) to use for reranking the documents. + * @param documents - A list of documents (as JSON strings) to be reranked. + * @param options - Optional dictionary of settings for the reranking request. + * Known service options: + * - `return_documents` (boolean) — include reranked documents in the response. + * - `top_k` (number) — max number of top-ranked documents to return. + * - `batch_size` (number) — batch size for processing documents. + * - `sort` (boolean) — sort results by relevance score in descending order. + * - `document_type` (`"string"` | `"json"`) — type of documents being reranked. + * - `target_paths` (string) — comma-separated JSON paths (when document_type is `"json"`). + * - `abortSignal` (AbortSignal) — signal to cancel the request. + * Any additional keys are forwarded as-is to the inference service. + * @returns The reranking results including scored documents, latency, and token usage. + * + * @example Semantic reranking of query results + * ```ts snippet:ContainerSemanticRerank + * import { DefaultAzureCredential } from "@azure/identity"; + * import { CosmosClient } from "@azure/cosmos"; + * + * const endpoint = "https://your-account.documents.azure.com"; + * const aadCredentials = new DefaultAzureCredential(); + * const client = new CosmosClient({ + * endpoint, + * aadCredentials, + * }); + * + * const { database } = await client.databases.createIfNotExists({ id: "Test Database" }); + * const { container } = await database.containers.createIfNotExists({ id: "Test Container" }); + * + * const queryResults = ["doc1 JSON", "doc2 JSON", "doc3 JSON"]; + * const result = await container.semanticRerank( + * "most economical with multiple adjustments", + * queryResults, + * { return_documents: true, top_k: 10, sort: true }, + * ); + * // Access the top-ranked document + * if (result.rerankScores.length > 0) { + * const topResult = result.rerankScores[0]; + * const topScore = topResult.score; + * const topDocument = topResult.document; + * if (topDocument !== null) { + * console.log("Top-ranked document:", topDocument); + * } + * console.log("Top score:", topScore); + * } + * ``` + */ + public async semanticRerank( + context: string, + documents: string[], + options?: SemanticRerankOptions, + ): Promise { + return this.clientContext.semanticRerank(context, documents, options); + } + /** * @internal */ diff --git a/sdk/cosmosdb/cosmos/src/common/constants.ts b/sdk/cosmosdb/cosmos/src/common/constants.ts index 9f0c787a9aac..f601d2a27e46 100644 --- a/sdk/cosmosdb/cosmos/src/common/constants.ts +++ b/sdk/cosmosdb/cosmos/src/common/constants.ts @@ -304,6 +304,13 @@ export const Constants = { EncryptionCacheRefreshIntervalInMs: 60000, // 1 minute RequestTimeoutForReadsInMs: 2000, // 2 seconds + + // Inference Service + InferenceBasePath: "/inference/semanticReranking", + InferenceUserAgent: "cosmos-inference-js", + InferenceDefaultScope: "https://dbinference.azure.com/.default", + InferenceDefaultTimeoutMs: 120_000, // 120 seconds + InferenceEndpointEnvVar: "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", }; export const AAD_DEFAULT_SCOPE = "https://cosmos.azure.com/.default"; diff --git a/sdk/cosmosdb/cosmos/src/common/statusCodes.ts b/sdk/cosmosdb/cosmos/src/common/statusCodes.ts index 96e11bea766a..f1758683b956 100644 --- a/sdk/cosmosdb/cosmos/src/common/statusCodes.ts +++ b/sdk/cosmosdb/cosmos/src/common/statusCodes.ts @@ -11,6 +11,9 @@ export interface StatusCodesType { Accepted: 202; NoContent: 204; MultiStatus: 207; + + // Redirection + MultipleChoices: 300; NotModified: 304; // Client error @@ -50,6 +53,9 @@ export const StatusCodes: StatusCodesType = { Accepted: 202, NoContent: 204, MultiStatus: 207, + + // Redirection + MultipleChoices: 300, NotModified: 304, // Client error diff --git a/sdk/cosmosdb/cosmos/src/index.ts b/sdk/cosmosdb/cosmos/src/index.ts index 0ea8ae04a4a9..f01d7cd91c32 100644 --- a/sdk/cosmosdb/cosmos/src/index.ts +++ b/sdk/cosmosdb/cosmos/src/index.ts @@ -164,3 +164,9 @@ export { type CosmosEncryptedNumber, CosmosEncryptedNumberType, } from "./encryption/index.js"; + +export type { + RerankScore, + SemanticRerankResult, + SemanticRerankOptions, +} from "./inference/index.js"; diff --git a/sdk/cosmosdb/cosmos/src/inference/InferenceService.ts b/sdk/cosmosdb/cosmos/src/inference/InferenceService.ts new file mode 100644 index 000000000000..7a4e966941a6 --- /dev/null +++ b/sdk/cosmosdb/cosmos/src/inference/InferenceService.ts @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import type { TokenCredential } from "@azure/core-auth"; +import type { + HttpClient, + Pipeline, + PipelineRequest, + PipelineResponse, +} from "@azure/core-rest-pipeline"; +import { + bearerTokenAuthenticationPolicy, + createEmptyPipeline, + createPipelineRequest, +} from "@azure/core-rest-pipeline"; +import type { AzureLogger } from "@azure/logger"; +import { createClientLogger } from "@azure/logger"; +import type { CosmosClientOptions } from "../CosmosClientOptions.js"; +import type { SemanticRerankOptions } from "./SemanticRerankOptions.js"; +import type { RerankScore, SemanticRerankResult } from "./SemanticRerankResult.js"; +import { Constants } from "../common/constants.js"; +import { StatusCodes } from "../common/statusCodes.js"; +import { getCachedDefaultHttpClient } from "../utils/cachedClient.js"; +import { ErrorResponse } from "../request/ErrorResponse.js"; + +const logger: AzureLogger = createClientLogger("InferenceService"); + +/** Keys that are not part of the inference service payload. */ +const NON_PAYLOAD_KEYS = new Set(["abortSignal"]); + +/** + * Provides functionality to interact with the Cosmos DB Inference Service for semantic reranking. + * @internal + */ +export class InferenceService { + private readonly pipeline: Pipeline; + private readonly httpClient: HttpClient; + private readonly inferenceEndpointUrl: string; + + constructor(cosmosClientOptions: CosmosClientOptions) { + if (!cosmosClientOptions.aadCredentials) { + throw new Error( + "Semantic rerank requires AAD authentication. Provide 'aadCredentials' in CosmosClientOptions.", + ); + } + + const endpoint = this.resolveInferenceEndpoint(cosmosClientOptions); + this.inferenceEndpointUrl = `${endpoint}${Constants.InferenceBasePath}`; + + this.pipeline = this.createInferencePipeline(cosmosClientOptions.aadCredentials); + this.httpClient = cosmosClientOptions.httpClient ?? getCachedDefaultHttpClient(); + + logger.info(`InferenceService initialized with endpoint: ${endpoint}`); + } + + /** + * Sends a semantic rerank request to the inference service. + * @param context - The context (e.g. query string) to use for reranking. + * @param documents - The documents to be reranked. + * @param options - Optional settings for the reranking request. + * @returns The reranking results including scores, latency, and token usage. + */ + async semanticRerank( + context: string, + documents: string[], + options?: SemanticRerankOptions, + ): Promise { + const payload = this.buildPayload(context, documents, options); + + const request = createPipelineRequest({ + url: this.inferenceEndpointUrl, + method: "POST", + body: JSON.stringify(payload), + abortSignal: options?.["abortSignal"] as AbortSignal | undefined, + timeout: Constants.InferenceDefaultTimeoutMs, + }); + + this.buildHeaders(request); + + const response = await this.pipeline.sendRequest(this.httpClient, request); + return this.parseResponse(response); + } + + /** + * Resolves the inference endpoint from client options or the environment variable. + * Client options take priority over the environment variable. + */ + private resolveInferenceEndpoint(cosmosClientOptions: CosmosClientOptions): string { + const endpoint = + cosmosClientOptions.inferenceEndpoint || + (typeof process !== "undefined" ? process.env[Constants.InferenceEndpointEnvVar] : undefined); + + if (!endpoint) { + throw new Error( + `Inference endpoint is required for semantic reranking. ` + + `Set 'inferenceEndpoint' in CosmosClientOptions or the '${Constants.InferenceEndpointEnvVar}' environment variable.`, + ); + } + + // Remove trailing slash if present + return endpoint.replace(/\/+$/, ""); + } + + /** + * Creates a pipeline configured for inference service authentication. + */ + private createInferencePipeline(credential: TokenCredential): Pipeline { + const pipeline = createEmptyPipeline(); + pipeline.addPolicy( + bearerTokenAuthenticationPolicy({ + credential, + scopes: Constants.InferenceDefaultScope, + }), + ); + return pipeline; + } + + /** + * Sets the required HTTP headers on an inference service request. + */ + private buildHeaders(request: PipelineRequest): void { + request.headers.set("Content-Type", "application/json"); + request.headers.set("Accept", "application/json"); + request.headers.set("Cache-Control", "no-cache"); + request.headers.set(Constants.HttpHeaders.Version, Constants.CurrentVersion); + request.headers.set(Constants.HttpHeaders.UserAgent, Constants.InferenceUserAgent); + request.headers.set(Constants.HttpHeaders.CustomUserAgent, Constants.InferenceUserAgent); + } + + /** + * Builds the JSON payload for the semantic rerank request. + */ + private buildPayload( + context: string, + documents: string[], + options?: SemanticRerankOptions, + ): Record { + const payload: Record = {}; + + if (options) { + // Forward all option keys except non-payload keys (e.g. abortSignal) + for (const [key, value] of Object.entries(options)) { + if (!NON_PAYLOAD_KEYS.has(key) && value !== undefined) { + payload[key] = value; + } + } + } + + // Required fields are set last to prevent options from overriding them + payload["query"] = context; + payload["documents"] = documents; + + return payload; + } + + /** + * Parses the HTTP response into a SemanticRerankResult. + * + * Note: The inference API response uses mixed casing conventions: + * - PascalCase: `Scores` (array of rerank results) + * - camelCase: `latency` (timing info), `document`, `score`, `index` + * - snake_case: `token_usage` (token consumption) + * This is the actual service response format, not a bug. + */ + private parseResponse(response: PipelineResponse): SemanticRerankResult { + if (response.status < StatusCodes.Ok || response.status >= StatusCodes.MultipleChoices) { + let serviceCode: string | number = response.status; + let serviceMessage = `Semantic rerank request failed with status ${response.status}`; + + // Parse the error payload to surface the service's code, message, and details + try { + const errorBody = JSON.parse(response.bodyAsText || "{}"); + if (errorBody.code) { + serviceCode = errorBody.code; + } + if (errorBody.message) { + serviceMessage = errorBody.message; + } + if (errorBody.details) { + serviceMessage += ` Details: ${JSON.stringify(errorBody.details)}`; + } + } catch { + // If parsing fails, fall back to raw body text + serviceMessage += `: ${response.bodyAsText}`; + } + + const errorResponse = new ErrorResponse(serviceMessage); + errorResponse.code = serviceCode; + errorResponse.headers = response.headers.toJSON() as Record; + throw errorResponse; + } + + const body = JSON.parse(response.bodyAsText || "{}"); + + const rerankScores: RerankScore[] = []; + if (Array.isArray(body.Scores)) { + for (const item of body.Scores) { + rerankScores.push({ + document: item.document ?? null, + score: typeof item.score === "number" ? item.score : 0, + index: typeof item.index === "number" ? item.index : -1, + }); + } + } + + return { + rerankScores, + latency: body.latency ?? undefined, + tokenUsage: body.token_usage ?? undefined, + headers: response.headers.toJSON() as Record, + }; + } +} diff --git a/sdk/cosmosdb/cosmos/src/inference/SemanticRerankOptions.ts b/sdk/cosmosdb/cosmos/src/inference/SemanticRerankOptions.ts new file mode 100644 index 000000000000..3d840ba044d9 --- /dev/null +++ b/sdk/cosmosdb/cosmos/src/inference/SemanticRerankOptions.ts @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * Options for a semantic reranking request, passed as a flat dictionary. + * + * Known service options (all optional): + * - `return_documents` (boolean) — if true, the reranked documents are included in the response. + * - `top_k` (number) — the maximum number of top-ranked documents to return. + * - `batch_size` (number) — the batch size for processing documents. + * - `sort` (boolean) — if true, results are sorted by relevance score in descending order. + * - `document_type` (`"string"` | `"json"`) — the type of documents being reranked. + * - `target_paths` (string) — comma-separated JSON paths to extract text from (when document_type is `"json"`). + * - `abortSignal` (AbortSignal) — signal to cancel the request (not sent as part of the payload). + * + * Any additional keys are forwarded as-is to the inference service payload. + */ +export type SemanticRerankOptions = Record; diff --git a/sdk/cosmosdb/cosmos/src/inference/SemanticRerankResult.ts b/sdk/cosmosdb/cosmos/src/inference/SemanticRerankResult.ts new file mode 100644 index 000000000000..e9e4f611d12f --- /dev/null +++ b/sdk/cosmosdb/cosmos/src/inference/SemanticRerankResult.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * Represents the score assigned to a document after a semantic reranking operation. + */ +export interface RerankScore { + /** The document content that was reranked. May be null if `return_documents` was not set. */ + document: string | null; + /** The relevance score assigned to the document after reranking. */ + score: number; + /** The original index of the document in the input list before reranking. */ + index: number; +} + +/** + * Represents the result of a semantic reranking operation, including rerank scores, + * latency, token usage, and HTTP response headers. + */ +export interface SemanticRerankResult { + /** The list of rerank scores for the documents. */ + rerankScores: RerankScore[]; + /** Latency information for the rerank operation. */ + latency: Record | undefined; + /** Token usage information for the rerank operation. */ + tokenUsage: Record | undefined; + /** HTTP response headers from the inference service. */ + headers: Record; +} diff --git a/sdk/cosmosdb/cosmos/src/inference/index.ts b/sdk/cosmosdb/cosmos/src/inference/index.ts new file mode 100644 index 000000000000..a33841c4a446 --- /dev/null +++ b/sdk/cosmosdb/cosmos/src/inference/index.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +export type { RerankScore, SemanticRerankResult } from "./SemanticRerankResult.js"; +export type { SemanticRerankOptions } from "./SemanticRerankOptions.js"; diff --git a/sdk/cosmosdb/cosmos/test/internal/unit/inference/inferenceService.spec.ts b/sdk/cosmosdb/cosmos/test/internal/unit/inference/inferenceService.spec.ts new file mode 100644 index 000000000000..a3528606786b --- /dev/null +++ b/sdk/cosmosdb/cosmos/test/internal/unit/inference/inferenceService.spec.ts @@ -0,0 +1,289 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { describe, it, assert, beforeEach, afterEach } from "vitest"; +import type { TokenCredential, GetTokenOptions, AccessToken } from "@azure/core-auth"; +import type { HttpClient, PipelineResponse } from "@azure/core-rest-pipeline"; +import { InferenceService } from "../../../../src/inference/InferenceService.js"; +import type { CosmosClientOptions } from "../../../../src/CosmosClientOptions.js"; + +class MockTokenCredential implements TokenCredential { + async getToken(scopes: string | string[], _options?: GetTokenOptions): Promise { + return { + token: "mock-token", + expiresOnTimestamp: Date.now() + 3600000, + }; + } +} + +function createMockOptions(overrides?: Partial): CosmosClientOptions { + return { + endpoint: "https://test-account.documents.azure.com:443/", + aadCredentials: new MockTokenCredential(), + ...overrides, + }; +} + +describe("InferenceService", { timeout: 10000 }, () => { + let originalEnv: string | undefined; + + beforeEach(() => { + originalEnv = process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = + "https://test-inference.dbinference.azure.com"; + }); + + afterEach(() => { + if (originalEnv !== undefined) { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = originalEnv; + } else { + delete process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + } + }); + + describe("constructor", () => { + it("should throw when aadCredentials is not provided", () => { + assert.throws( + () => new InferenceService({ endpoint: "https://test.documents.azure.com" }), + /AAD authentication/, + ); + }); + + it("should throw when no inference endpoint is configured", () => { + delete process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + assert.throws( + () => + new InferenceService({ + endpoint: "https://test.documents.azure.com", + aadCredentials: new MockTokenCredential(), + }), + /Inference endpoint is required/, + ); + }); + + it("should use inferenceEndpoint from client options over environment variable", () => { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = + "https://env-inference.dbinference.azure.com"; + const service = new InferenceService( + createMockOptions({ + inferenceEndpoint: "https://options-inference.dbinference.azure.com", + }), + ); + const resolvedUrl = (service as any).inferenceEndpointUrl as string; + assert.include(resolvedUrl, "options-inference"); + assert.notInclude(resolvedUrl, "env-inference"); + }); + + it("should fall back to environment variable when inferenceEndpoint is not in client options", () => { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = + "https://env-inference.dbinference.azure.com"; + const service = new InferenceService(createMockOptions()); + const resolvedUrl = (service as any).inferenceEndpointUrl as string; + assert.include(resolvedUrl, "env-inference"); + }); + + it("should succeed with valid AAD credentials and inference endpoint", () => { + const service = new InferenceService(createMockOptions()); + assert.isDefined(service); + }); + + it("should read inference endpoint from environment variable", () => { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = + "https://env-inference.dbinference.azure.com"; + const service = new InferenceService({ + endpoint: "https://test.documents.azure.com", + aadCredentials: new MockTokenCredential(), + }); + const resolvedUrl = (service as any).inferenceEndpointUrl as string; + assert.include(resolvedUrl, "env-inference"); + }); + }); + + describe("semanticRerank", () => { + it("should send correct payload with basic parameters", async () => { + let capturedBody: string | undefined; + + const service = new InferenceService(createMockOptions()); + + // Replace the pipeline's sendRequest to capture the request + const mockResponse: PipelineResponse = { + headers: { + toJSON: () => ({ "x-ms-request-id": "test-id" }), + } as any, + request: {} as any, + status: 200, + bodyAsText: JSON.stringify({ + Scores: [ + { document: "Doc 1 content", score: 0.95, index: 0 }, + { document: "Doc 2 content", score: 0.8, index: 1 }, + ], + latency: { total_ms: 100 }, + token_usage: { prompt_tokens: 50, total_tokens: 100 }, + }), + }; + + // Access private pipeline to mock sendRequest + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async (client: HttpClient, request: any) => { + capturedBody = request.body; + return mockResponse; + }; + + const result = await service.semanticRerank("test query", ["doc1", "doc2"]); + + assert.isDefined(capturedBody); + const parsedBody = JSON.parse(capturedBody!); + assert.equal(parsedBody.query, "test query"); + assert.deepEqual(parsedBody.documents, ["doc1", "doc2"]); + + // Verify response parsing + assert.equal(result.rerankScores.length, 2); + assert.equal(result.rerankScores[0].score, 0.95); + assert.equal(result.rerankScores[0].index, 0); + assert.equal(result.rerankScores[0].document, "Doc 1 content"); + assert.equal(result.rerankScores[1].score, 0.8); + assert.isDefined(result.latency); + assert.isDefined(result.tokenUsage); + assert.isDefined(result.headers); + }); + + it("should include optional parameters in payload", async () => { + let capturedBody: string | undefined; + + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async (_client: HttpClient, request: any) => { + capturedBody = request.body; + return { + headers: { toJSON: () => ({}) } as any, + request: {} as any, + status: 200, + bodyAsText: JSON.stringify({ Scores: [] }), + }; + }; + + await service.semanticRerank("test query", ["doc1"], { + return_documents: true, + top_k: 10, + batch_size: 32, + sort: true, + custom_param: "value", + }); + + const parsedBody = JSON.parse(capturedBody!); + assert.equal(parsedBody.return_documents, true); + assert.equal(parsedBody.top_k, 10); + assert.equal(parsedBody.batch_size, 32); + assert.equal(parsedBody.sort, true); + assert.equal(parsedBody.custom_param, "value"); + }); + + it("should throw on non-success HTTP status with plain text body", async () => { + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async () => ({ + headers: { toJSON: () => ({}) } as any, + request: {} as any, + status: 500, + bodyAsText: "Internal Server Error", + }); + + try { + await service.semanticRerank("query", ["doc"]); + assert.fail("Should have thrown"); + } catch (e: any) { + assert.include(e.message, "status 500"); + assert.include(e.message, "Internal Server Error"); + } + }); + + it("should surface structured error payload from service", async () => { + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async () => ({ + headers: { toJSON: () => ({ "x-ms-request-id": "err-id" }) } as any, + request: {} as any, + status: 400, + bodyAsText: JSON.stringify({ + code: "InvalidRequest", + message: "Error while formatting json document for the target paths Tas.", + details: null, + }), + }); + + try { + await service.semanticRerank("query", ["doc"]); + assert.fail("Should have thrown"); + } catch (e: any) { + assert.equal(e.code, "InvalidRequest"); + assert.include(e.message, "Error while formatting json document"); + assert.isDefined(e.headers); + } + }); + + it("should include document_type and target_paths in payload", async () => { + let capturedBody: string | undefined; + + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async (_client: HttpClient, request: any) => { + capturedBody = request.body; + return { + headers: { toJSON: () => ({}) } as any, + request: {} as any, + status: 200, + bodyAsText: JSON.stringify({ Scores: [] }), + }; + }; + + await service.semanticRerank("test query", ["doc1"], { + document_type: "json", + target_paths: "/name,/description", + }); + + const parsedBody = JSON.parse(capturedBody!); + assert.equal(parsedBody.document_type, "json"); + assert.equal(parsedBody.target_paths, "/name,/description"); + }); + + it("should handle empty scores in response", async () => { + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async () => ({ + headers: { toJSON: () => ({}) } as any, + request: {} as any, + status: 200, + bodyAsText: JSON.stringify({}), + }); + + const result = await service.semanticRerank("query", ["doc"]); + assert.deepEqual(result.rerankScores, []); + assert.isUndefined(result.latency); + assert.isUndefined(result.tokenUsage); + }); + + it("should handle null document in score", async () => { + const service = new InferenceService(createMockOptions()); + + const pipeline = (service as any).pipeline; + pipeline.sendRequest = async () => ({ + headers: { toJSON: () => ({}) } as any, + request: {} as any, + status: 200, + bodyAsText: JSON.stringify({ + Scores: [{ document: null, score: 0.9, index: 0 }], + }), + }); + + const result = await service.semanticRerank("query", ["doc"]); + assert.equal(result.rerankScores.length, 1); + assert.isNull(result.rerankScores[0].document); + assert.equal(result.rerankScores[0].score, 0.9); + }); + }); +}); diff --git a/sdk/cosmosdb/cosmos/test/internal/unit/inference/semanticRerank.spec.ts b/sdk/cosmosdb/cosmos/test/internal/unit/inference/semanticRerank.spec.ts new file mode 100644 index 000000000000..70b2db560fba --- /dev/null +++ b/sdk/cosmosdb/cosmos/test/internal/unit/inference/semanticRerank.spec.ts @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { describe, it, assert, beforeEach, afterEach } from "vitest"; +import type { TokenCredential, GetTokenOptions, AccessToken } from "@azure/core-auth"; +import { CosmosClient } from "../../../../src/CosmosClient.js"; + +class MockTokenCredential implements TokenCredential { + async getToken(scopes: string | string[], _options?: GetTokenOptions): Promise { + return { + token: "mock-token", + expiresOnTimestamp: Date.now() + 3600000, + }; + } +} + +describe("Container.semanticRerank", { timeout: 10000 }, () => { + let originalEnv: string | undefined; + + beforeEach(() => { + originalEnv = process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = + "https://test-inference.dbinference.azure.com"; + }); + + afterEach(() => { + if (originalEnv !== undefined) { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = originalEnv; + } else { + delete process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + } + }); + + it("should throw when client is not using AAD authentication", async () => { + const client = new CosmosClient({ + endpoint: "https://test-account.documents.azure.com:443/", + key: "dGVzdC1rZXk=", // base64 "test-key" + }); + + const container = client.database("testdb").container("testcol"); + + try { + await container.semanticRerank("query", ["doc1"]); + assert.fail("Should have thrown"); + } catch (e: any) { + assert.include(e.message, "AAD authentication"); + } finally { + client.dispose(); + } + }); + + it("should throw when inference endpoint is not configured", async () => { + const savedEnv = process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + delete process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; + + try { + const client = new CosmosClient({ + endpoint: "https://test-account.documents.azure.com:443/", + aadCredentials: new MockTokenCredential(), + }); + + const container = client.database("testdb").container("testcol"); + + try { + await container.semanticRerank("query", ["doc1"]); + assert.fail("Should have thrown"); + } catch (e: any) { + assert.include(e.message, "Inference endpoint is required"); + } finally { + client.dispose(); + } + } finally { + if (savedEnv !== undefined) { + process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT = savedEnv; + } + } + }); + + it("should delegate to ClientContext.semanticRerank", async () => { + const client = new CosmosClient({ + endpoint: "https://test-account.documents.azure.com:443/", + aadCredentials: new MockTokenCredential(), + }); + + const container = client.database("testdb").container("testcol"); + + // Verify the method exists and is callable + assert.isFunction(container.semanticRerank); + + client.dispose(); + }); + + it("should clean up inference service on client dispose", () => { + const client = new CosmosClient({ + endpoint: "https://test-account.documents.azure.com:443/", + aadCredentials: new MockTokenCredential(), + }); + + // Dispose should not throw + assert.doesNotThrow(() => client.dispose()); + }); +}); diff --git a/sdk/cosmosdb/cosmos/test/public/integration/semanticRerank.spec.ts b/sdk/cosmosdb/cosmos/test/public/integration/semanticRerank.spec.ts new file mode 100644 index 000000000000..bdad7a870bd5 --- /dev/null +++ b/sdk/cosmosdb/cosmos/test/public/integration/semanticRerank.spec.ts @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { DefaultAzureCredential } from "@azure/identity"; +import { CosmosClient } from "../../../src/index.js"; +import type { SemanticRerankResult } from "../../../src/index.js"; +import { describe, it, assert, beforeAll, afterAll } from "vitest"; + +/** + * Integration tests for the Semantic Rerank feature. + * + * These tests require: + * 1. AAD credentials with access to the Cosmos DB inference service + * 2. An inference endpoint registered for the Cosmos DB account + * + * Environment variables: + * - SEMANTIC_RERANK_ACCOUNT_ENDPOINT: Cosmos DB account endpoint + * - AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT: Inference service endpoint + * (e.g. "https://\{account\}.\{region\}.dbinference.azure.com") + * - AZURE_TENANT_ID: Azure AD tenant ID (optional, for DefaultAzureCredential) + * + * For the full-text-search + rerank test, additionally: + * - A database "virtualstore" with container "sportinggoods" and sample documents + */ +const accountEndpoint = process.env.SEMANTIC_RERANK_ACCOUNT_ENDPOINT; +const inferenceEndpoint = process.env.AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT; +const hasRequiredEnv = Boolean(accountEndpoint && inferenceEndpoint); + +describe.skipIf(!hasRequiredEnv)("SemanticRerankIntegration", { timeout: 120000 }, () => { + let client: CosmosClient; + + beforeAll(() => { + const aadCredentials = new DefaultAzureCredential(); + client = new CosmosClient({ + endpoint: accountEndpoint!, + aadCredentials, + }); + }); + + afterAll(() => { + client?.dispose(); + }); + + it("should rerank documents with scores, latency, and token usage", async () => { + // Use a placeholder container — the inference service is container-agnostic, + // it only needs the inference endpoint and AAD credentials. + const container = client.database("testdb").container("testcol"); + + const documents = [ + "Berlin is the capital of Germany.", + "Paris is the capital of France.", + "Madrid is the capital of Spain.", + ]; + + const context = "What is the capital of France?"; + + const result: SemanticRerankResult = await container.semanticRerank(context, documents, { + return_documents: true, + top_k: 10, + batch_size: 32, + }); + + // Verify scores are returned and correctly ordered + assert.isAbove(result.rerankScores.length, 0, "Should have rerank scores"); + assert.isAtMost(result.rerankScores.length, 3, "Should have at most 3 scores"); + + // The document about Paris/France should have the highest score + const topScore = result.rerankScores[0]; + assert.equal(topScore.index, 1, "Paris document (index 1) should rank highest"); + assert.isAbove(topScore.score, 0.5, "Top score should be well above 0.5"); + assert.equal(topScore.document, "Paris is the capital of France."); + + // Verify all scores have required fields + for (const score of result.rerankScores) { + assert.isNumber(score.score, "Score should be a number"); + assert.isNumber(score.index, "Index should be a number"); + assert.isString(score.document, "Document should be a string when return_documents is true"); + } + + // Verify metadata + assert.isDefined(result.latency, "Latency should be present in the result"); + assert.isDefined(result.tokenUsage, "Token usage should be present in the result"); + assert.isDefined(result.headers, "Headers should be present in the result"); + }); + + it("should rerank without returning documents when returnDocuments is not set", async () => { + const container = client.database("testdb").container("testcol"); + + const documents = ["Berlin is the capital of Germany.", "Paris is the capital of France."]; + + const result: SemanticRerankResult = await container.semanticRerank( + "What is the capital of France?", + documents, + ); + + assert.isAbove(result.rerankScores.length, 0, "Should have rerank scores"); + for (const score of result.rerankScores) { + assert.isNumber(score.score, "Score should be a number"); + assert.isNumber(score.index, "Index should be a number"); + } + }); + + /** + * End-to-end test: queries documents from a pre-existing Cosmos DB container, + * then reranks the results using the inference service. + * + * Prerequisite: database "rerank-test" with container "products" (partitioned by /category) + * must exist on the Cosmos DB account with sample documents already inserted. + */ + it.skip("should query Cosmos DB documents and rerank them", async () => { + const container = client.database("rerank-test").container("products"); + + // Step 1: Insert sample sporting goods documents + const sampleItems = [ + { + id: "sr-1", + category: "fitness", + name: "ProFit Power Tower", + description: + "Professional power tower with integrated pull-up bar, dip station, and vertical knee raise. Heavy-duty steel frame supports up to 300 lbs. Multiple grip positions for varied workouts. Ideal for home gyms with limited space.", + }, + { + id: "sr-2", + category: "fitness", + name: "FlexForce Cable Machine", + description: + "Compact cable crossover machine with multiple pulley adjustments. Features 200 lb weight stack and smooth motion guide rods. Perfect for strength training exercises including chest flys, lat pulldowns, and cable rows.", + }, + { + id: "sr-3", + category: "fitness", + name: "IronGrip Adjustable Dumbbells", + description: + "Quick-change adjustable dumbbell set ranging from 5 to 52.5 lbs per hand. Replaces 15 sets of weights. Space-saving design with durable steel construction and comfortable grip.", + }, + { + id: "sr-4", + category: "fitness", + name: "EnduraRun Treadmill", + description: + "Folding treadmill with cushioned running deck and 12 incline levels. Built-in heart rate monitor and Bluetooth speaker. Supports speeds up to 12 mph. Compact folding design for apartment living.", + }, + { + id: "sr-5", + category: "fitness", + name: "BudgetFlex Home Gym", + description: + "Most economical home gym system with integrated pull-up bar and multiple pulley adjustments. Affordable yet sturdy construction ideal for home gyms. Includes leg press attachment and preacher curl pad.", + }, + ]; + + try { + for (const item of sampleItems) { + await container.items.upsert(item); + } + + // Step 2: Query documents using a standard Cosmos DB query + const { resources: queryResults } = await container.items + .query("SELECT c.id, c.name, c.description FROM c WHERE c.category = 'fitness'") + .fetchAll(); + + const documents: string[] = (queryResults ?? []).map((item) => JSON.stringify(item)); + assert.isAbove(documents.length, 0, "Should have documents from query"); + + // Step 3: Rerank the query results using semantic reranker + const context = "most economical with multiple pulley adjustments and ideal for home gyms"; + + const result: SemanticRerankResult = await container.semanticRerank(context, documents, { + return_documents: true, + top_k: 10, + batch_size: 32, + }); + + // Step 4: Verify the rerank result + assert.isAbove(result.rerankScores.length, 0, "Should have rerank scores"); + assert.isDefined(result.latency, "Latency should be present"); + assert.isDefined(result.tokenUsage, "Token usage should be present"); + + // The BudgetFlex Home Gym (id: "sr-5") should rank highest since its description + // directly matches the rerank context about "most economical" and "pulley adjustments" + const topDoc = result.rerankScores[0]; + assert.isNotNull(topDoc.document, "Top document should be returned"); + assert.include( + topDoc.document!, + "economical", + "Top result should be the most relevant to the rerank context", + ); + } finally { + // Clean up: delete inserted items + for (const item of sampleItems) { + try { + await container.item(item.id, item.category).delete(); + } catch { + // Ignore cleanup errors + } + } + } + }); +}); diff --git a/sdk/cosmosdb/cosmos/test/snippets.spec.ts b/sdk/cosmosdb/cosmos/test/snippets.spec.ts index a0bf3bb11b73..f63c50f9562d 100644 --- a/sdk/cosmosdb/cosmos/test/snippets.spec.ts +++ b/sdk/cosmosdb/cosmos/test/snippets.spec.ts @@ -1841,4 +1841,32 @@ describe("snippets", () => { } } }); + it("ContainerSemanticRerank", async () => { + const endpoint = "https://your-account.documents.azure.com"; + const aadCredentials = new DefaultAzureCredential(); + const client = new CosmosClient({ + endpoint, + aadCredentials, + }); + // @ts-preserve-whitespace + const { database } = await client.databases.createIfNotExists({ id: "Test Database" }); + const { container } = await database.containers.createIfNotExists({ id: "Test Container" }); + // @ts-preserve-whitespace + const queryResults = ["doc1 JSON", "doc2 JSON", "doc3 JSON"]; + const result = await container.semanticRerank( + "most economical with multiple adjustments", + queryResults, + { return_documents: true, top_k: 10, sort: true }, + ); + // Access the top-ranked document + if (result.rerankScores.length > 0) { + const topResult = result.rerankScores[0]; + const topScore = topResult.score; + const topDocument = topResult.document; + if (topDocument !== null) { + console.log("Top-ranked document:", topDocument); + } + console.log("Top score:", topScore); + } + }); });