diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainer.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainer.cs index b968ad5bb9..48615d818b 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainer.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainer.cs @@ -1008,6 +1008,21 @@ public override Task IsFeedRangePartOfAsync( } #endif +#if PREVIEW && SDKPROJECTREF + public override Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default) + { + return this.container.SemanticRerankAsync( + rerankContext, + documents, + options, + cancellationToken); + } +#endif + private async Task ReadManyItemsHelperAsync( IReadOnlyList<(string id, PartitionKey partitionKey)> items, ReadManyRequestOptions readManyRequestOptions = null, diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs index fec9fa0f3f..47f0c9ed8a 100644 --- a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs @@ -732,6 +732,21 @@ public override FeedIterator GetItemQueryIterator( } #if ENCRYPTIONPREVIEW +#if SDKPROJECTREF + public override Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default) + { + return this.Container.SemanticRerankAsync( + rerankContext, + documents, + options, + cancellationToken); + } + +#endif public override async Task DeleteAllItemsByPartitionKeyStreamAsync( Cosmos.PartitionKey partitionKey, RequestOptions requestOptions = null, diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProvider.cs index b7bc7a4475..9839bf1039 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProvider.cs @@ -52,6 +52,12 @@ public abstract ValueTask GetUserAuthorizationTokenAsync( AuthorizationTokenType tokenType, ITrace trace); + public abstract ValueTask AddInferenceAuthorizationHeaderAsync( + INameValueCollection headersCollection, + Uri requestAddress, + string verb, + AuthorizationTokenType tokenType); + public abstract void TraceUnauthorized( DocumentClientException dce, string authorizationToken, diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderMasterKey.cs b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderMasterKey.cs index 134640ba10..278be856eb 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderMasterKey.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderMasterKey.cs @@ -214,6 +214,11 @@ private void Dispose(bool disposing) this.authKeyHashFunction = null; } + public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType) + { + throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD"); + } + // Use C# finalizer syntax for finalization code. // This finalizer will run only if the Dispose method does not get called. // It gives your base class the opportunity to finalize. diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderResourceToken.cs b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderResourceToken.cs index 182d4c5ba7..1697589e5d 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderResourceToken.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderResourceToken.cs @@ -92,6 +92,11 @@ private void Dispose(bool disposing) // Do nothing } + public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType) + { + throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD"); + } + // Use C# finalizer syntax for finalization code. // This finalizer will run only if the Dispose method does not get called. // It gives your base class the opportunity to finalize. diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs index dff20331f6..09e422bb8a 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs @@ -15,14 +15,18 @@ namespace Microsoft.Azure.Cosmos internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationTokenProvider { + private const string InferenceTokenPrefix = "Bearer "; internal readonly TokenCredentialCache tokenCredentialCache; private bool isDisposed = false; + internal readonly TokenCredential tokenCredential; + public AuthorizationTokenProviderTokenCredential( TokenCredential tokenCredential, Uri accountEndpoint, TimeSpan? backgroundTokenCredentialRefreshInterval) { + this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); this.tokenCredentialCache = new TokenCredentialCache( tokenCredential: tokenCredential, accountEndpoint: accountEndpoint, @@ -71,6 +75,21 @@ public override async ValueTask AddAuthorizationHeaderAsync( } } + public override async ValueTask AddInferenceAuthorizationHeaderAsync( + INameValueCollection headersCollection, + Uri requestAddress, + string verb, + AuthorizationTokenType tokenType) + { + using (Trace trace = Trace.GetRootTrace(nameof(GetUserAuthorizationTokenAsync), TraceComponent.Authorization, TraceLevel.Info)) + { + string token = await this.tokenCredentialCache.GetTokenAsync(trace); + + string inferenceToken = $"{InferenceTokenPrefix}{token}"; + headersCollection.Add(HttpConstants.HttpHeaders.Authorization, inferenceToken); + } + } + public override void TraceUnauthorized( DocumentClientException dce, string authorizationToken, diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AzureKeyCredentialAuthorizationTokenProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/AzureKeyCredentialAuthorizationTokenProvider.cs index 602c419a6c..03a0bff2a4 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/AzureKeyCredentialAuthorizationTokenProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/AzureKeyCredentialAuthorizationTokenProvider.cs @@ -125,5 +125,10 @@ private void CheckAndRefreshTokenProvider() } } } + + public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType) + { + throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD"); + } } } diff --git a/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs new file mode 100644 index 0000000000..6f808fa691 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs @@ -0,0 +1,209 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + using System.Net.Http; + using System.Net.Http.Headers; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using global::Azure.Core; + using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Collections; + + /// + /// Provides functionality to interact with the Cosmos DB Inference Service for semantic reranking. + /// + internal class InferenceService : IDisposable + { + // Base path for the inference service endpoint. + private const string basePath = "/inference/semanticReranking"; + // User agent string for inference requests. + private const string inferenceUserAgent = "cosmos-inference-dotnet"; + // Default scope for AAD authentication. + private const string inferenceServiceDefaultScope = "https://dbinference.azure.com/.default"; + private const int inferenceServiceDefaultMaxConnectionLimit = 50; + + private readonly int inferenceServiceMaxConnectionLimit; + private readonly string inferenceServiceBaseUrl; + private readonly Uri inferenceEndpoint; + + private HttpClient httpClient; + private AuthorizationTokenProvider cosmosAuthorization; + + private bool disposedValue; + + /// + /// Initializes a new instance of the class. + /// + /// The CosmosClient instance. + /// Thrown if AAD authentication is not used. + public InferenceService(CosmosClient client) + { + this.inferenceServiceBaseUrl = ConfigurationManager.GetEnvironmentVariable("AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", null); + + if (string.IsNullOrEmpty(this.inferenceServiceBaseUrl)) + { + throw new ArgumentNullException("Set environment variable AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT to use inference service"); + } + + this.inferenceServiceMaxConnectionLimit = ConfigurationManager.GetEnvironmentVariable( + "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT", + inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit; + + // Create and configure HttpClient for inference requests. + HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler( + gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit, + webProxy: null, + serverCertificateCustomValidationCallback: client.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback); + + this.httpClient = new HttpClient(httpMessageHandler); + + this.CreateClientHelper(this.httpClient); + + // Construct the inference service endpoint URI. + this.inferenceEndpoint = new Uri($"{this.inferenceServiceBaseUrl}/{basePath}"); + + // Ensure AAD authentication is used. + if (client.DocumentClient.cosmosAuthorization.GetType() != typeof(AuthorizationTokenProviderTokenCredential)) + { + throw new InvalidOperationException("InferenceService only supports AAD authentication."); + } + + // Set up token credential for authorization. + // This is done to ensure the correct scope, which is different than the scope of the client, is used for the inference service. + AuthorizationTokenProviderTokenCredential defaultOperationTokenProvider = client.DocumentClient.cosmosAuthorization as AuthorizationTokenProviderTokenCredential; + TokenCredential tokenCredential = defaultOperationTokenProvider.tokenCredential; + + this.cosmosAuthorization = new AuthorizationTokenProviderTokenCredential( + tokenCredential: tokenCredential, + accountEndpoint: new Uri(inferenceServiceDefaultScope), + backgroundTokenCredentialRefreshInterval: client.ClientOptions?.TokenCredentialBackgroundRefreshInterval); + } + + /// + /// Sends a semantic rerank request to the inference service. + /// + /// The context/query for reranking. + /// The documents to be reranked. + /// Optional additional options for the request. + /// Cancellation token. + /// A dictionary containing the reranked results. + public async Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default) + { + // Prepare HTTP request for semantic reranking. + HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint); + INameValueCollection additionalHeaders = new RequestNameValueCollection(); + await this.cosmosAuthorization.AddInferenceAuthorizationHeaderAsync( + headersCollection: additionalHeaders, + this.inferenceEndpoint, + HttpConstants.HttpMethods.Post, + AuthorizationTokenType.AadToken); + additionalHeaders.Add(HttpConstants.HttpHeaders.UserAgent, inferenceUserAgent); + + // Add all headers to the HTTP request. + foreach (string key in additionalHeaders.AllKeys()) + { + message.Headers.Add(key, additionalHeaders[key]); + } + + // Build the request payload. + Dictionary body = this.AddSemanticRerankPayload(rerankContext, documents, options); + + message.Content = new StringContent( + Newtonsoft.Json.JsonConvert.SerializeObject(body), + Encoding.UTF8, + RuntimeConstants.MediaTypes.Json); + + // Send the request and ensure success. + HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken); + responseMessage.EnsureSuccessStatusCode(); + + // Deserialize and return the response content as a dictionary. + return await SemanticRerankResult.DeserializeSemanticRerankResultAsync(responseMessage); + } + + /// + /// Configures the provided HttpClient with default headers and settings for inference requests. + /// + /// The HttpClient to configure. + private void CreateClientHelper(HttpClient httpClient) + { + httpClient.Timeout = TimeSpan.FromSeconds(120); + httpClient.DefaultRequestHeaders.CacheControl = new CacheControlHeaderValue { NoCache = true }; + + // Set requested API version header for version enforcement. + httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Version, + HttpConstants.Versions.CurrentVersion); + + httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Accept, RuntimeConstants.MediaTypes.Json); + } + + /// + /// Constructs the payload for the semantic rerank request. + /// + /// The context/query for reranking. + /// The documents to be reranked. + /// Optional additional options. + /// A dictionary representing the request payload. + private Dictionary AddSemanticRerankPayload(string rerankContext, IEnumerable documents, IDictionary options) + { + Dictionary payload = new Dictionary + { + { "query", rerankContext }, + { "documents", documents.ToArray() } + }; + + if (options == null) + { + return payload; + } + + // Add any additional options to the payload. + foreach (string option in options.Keys) + { + payload.Add(option, options[option]); + } + + return payload; + } + + /// + /// Disposes managed resources used by the service. + /// + /// Indicates if called from Dispose. + protected void Dispose(bool disposing) + { + if (!this.disposedValue) + { + if (disposing) + { + this.httpClient.Dispose(); + this.cosmosAuthorization.Dispose(); + this.httpClient = null; + this.cosmosAuthorization = null; + } + + this.disposedValue = true; + } + } + + /// + /// Disposes the service and its resources. + /// + public void Dispose() + { + this.Dispose(true); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Inference/RerankScore.cs b/Microsoft.Azure.Cosmos/src/Inference/RerankScore.cs new file mode 100644 index 0000000000..407658cbe7 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Inference/RerankScore.cs @@ -0,0 +1,46 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + /// + /// Represents the score assigned to a document after a reranking operation. + /// +#if PREVIEW + public +#else + internal +#endif + + class RerankScore + { + /// + /// Gets the document content or identifier that was reranked. + /// + public object Document { get; } + + /// + /// Gets the score assigned to the document after reranking. + /// + public double Score { get; } + + /// + /// Gets the original index or position of the document before reranking. + /// + public int Index { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The document content or identifier. + /// The reranked score for the document. + /// The original index of the document. + public RerankScore(object document, double score, int index) + { + this.Document = document; + this.Score = score; + this.Index = index; + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Inference/SemanticRerankResult.cs b/Microsoft.Azure.Cosmos/src/Inference/SemanticRerankResult.cs new file mode 100644 index 0000000000..e228f8d0fe --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Inference/SemanticRerankResult.cs @@ -0,0 +1,131 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System.Collections.Generic; + using System.IO; + using System.Net.Http; + using System.Net.Http.Headers; + using System.Text.Json; + using System.Threading.Tasks; + + /// + /// Represents the result of a semantic reranking operation, including rerank scores, + /// latency, token usage, and HTTP response headers. + /// +#if PREVIEW + public +#else + internal +#endif + + class SemanticRerankResult + { + /// + /// Gets the HTTP response headers associated with the rerank operation. + /// + public HttpResponseHeaders Headers { get; } + + /// + /// Gets the list of rerank scores for the documents. + /// + public IReadOnlyList RerankScores { get; } + + /// + /// Gets the latency information for the rerank operation. + /// + public Dictionary Latency { get; } + + /// + /// Gets the token usage information for the rerank operation. + /// + public Dictionary TokenUseage { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The list of rerank scores. + /// The latency information. + /// The token usage information. + /// The HTTP response headers. + private SemanticRerankResult( + IReadOnlyList rerankScores, + Dictionary latency, + Dictionary tokenUseage, + HttpResponseHeaders headers) + { + this.RerankScores = rerankScores; + this.Latency = latency; + this.TokenUseage = tokenUseage; + this.Headers = headers; + } + + /// + /// Deserializes a from an HTTP response message asynchronously. + /// + /// The HTTP response message containing the rerank result. + /// A task that represents the asynchronous operation. The task result contains the deserialized . + internal static async Task DeserializeSemanticRerankResultAsync(HttpResponseMessage responseMessage) + { + Stream content = await responseMessage.Content.ReadAsStreamAsync(); + + using (content) + { + using (JsonDocument doc = await JsonDocument.ParseAsync(content)) + { + JsonElement root = doc.RootElement; + + // Parse Scores + List rerankScores = new List(); + if (root.TryGetProperty("Scores", out JsonElement scoresElement) && scoresElement.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement item in scoresElement.EnumerateArray()) + { + object document = null; + if (item.TryGetProperty("document", out JsonElement docElement)) + { + // Try to deserialize as an object + switch (docElement.ValueKind) + { + case JsonValueKind.Object: + document = JsonSerializer.Deserialize>(docElement.GetRawText()); + break; + case JsonValueKind.Null: + document = null; + break; + } + } + + double score = item.TryGetProperty("score", out JsonElement scoreElement) && scoreElement.TryGetDouble(out double s) ? s : 0.0; + int index = item.TryGetProperty("index", out JsonElement indexElement) && indexElement.TryGetInt32(out int i) ? i : -1; + + rerankScores.Add(new RerankScore(document, score, index)); + } + } + + // Parse latency + Dictionary latency = null; + if (root.TryGetProperty("latency", out JsonElement latencyElement) && latencyElement.ValueKind == JsonValueKind.Object) + { + latency = JsonSerializer.Deserialize>(latencyElement.GetRawText()); + } + + // Parse token_usage + Dictionary tokenUsage = null; + if (root.TryGetProperty("token_usage", out JsonElement tokenUsageElement) && tokenUsageElement.ValueKind == JsonValueKind.Object) + { + tokenUsage = JsonSerializer.Deserialize>(tokenUsageElement.GetRawText()); + } + + return new SemanticRerankResult( + rerankScores, + latency, + tokenUsage, + responseMessage.Headers); + } + } + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs b/Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs index faa9f50d2a..0014ed843a 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos { using System; + using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net.Http; @@ -34,6 +35,7 @@ internal class ClientContextCore : CosmosClientContext private readonly string userAgent; private bool isDisposed = false; + private InferenceService inferenceService = null; private ClientContextCore( CosmosClient client, @@ -467,6 +469,32 @@ await this.DocumentClient.OpenConnectionsToAllReplicasAsync( cancellationToken); } + /// + internal override async Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default) + { + InferenceService inferenceService = this.GetOrCreateInferenceService(); + return await inferenceService.SemanticRerankAsync(rerankContext, documents, options, cancellationToken); + } + + /// + internal override InferenceService GetOrCreateInferenceService() + { + if (this.inferenceService == null) + { + // Double check locking to avoid unnecessary locks + lock (this) + { + this.inferenceService ??= new InferenceService(this.client); + } + } + + return this.inferenceService; + } + public override void Dispose() { this.Dispose(true); @@ -484,6 +512,7 @@ protected virtual void Dispose(bool disposing) { this.batchExecutorCache.Dispose(); this.DocumentClient.Dispose(); + this.inferenceService?.Dispose(); } this.isDisposed = true; diff --git a/Microsoft.Azure.Cosmos/src/Resource/Container/Container.cs b/Microsoft.Azure.Cosmos/src/Resource/Container/Container.cs index 379ee407f3..0232eb6d78 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/Container/Container.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/Container/Container.cs @@ -1679,6 +1679,91 @@ public abstract ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilderWithManu string processorName, ChangeFeedStreamHandlerWithManualCheckpoint onChangesDelegate); +#if PREVIEW + /// + /// Rerank a list of documents using semantic reranking. + /// This method uses a semantic reranker to score and reorder the provided documents + /// based on their relevance to the given reranking context. + /// + /// The sematic reranking requests will not use the regular request flow and have it's own client. This will not use the default SDK retry policies. + /// + /// To use this feature, you must set up a Semantic Reranker resource in Azure and provide the endpoint and key via the environment variable: "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT" + /// By default the Semantic Reranking will have a default max connection limit of 50, to change this set the enviroment variable "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT" to the desired value before creating the CosmosClient. + /// + /// The context (ex: query string) to use for reranking the documents. + /// A list of documents to be reranked + /// (Optional) The options for the semantic reranking request. + /// (Optional) representing request cancellation. + /// The reranking results, typically including the reranked documents and their scores. + /// /// + /// + /// documents = new List(); + /// FeedIterator resultSetIterator = container.GetItemQueryIterator( + /// new QueryDefinition(queryString), + /// requestOptions: new QueryRequestOptions() + /// { + /// MaxItemCount = 15, + /// }); + /// + /// while (resultSetIterator.HasMoreResults) + /// { + /// FeedResponse response = await resultSetIterator.ReadNextAsync(); + /// foreach (JsonElement item in response) + /// { + /// documents.Add(item.ToString()); + /// } + /// } + /// + /// Dictionary options = new Dictionary + /// { + /// { "return_documents", true }, + /// { "top_k", 10 }, + /// { "batch_size", 32 }, + /// { "sort", true } + /// }; + /// + /// SemanticRerankResult results = await container.SemanticRerankAsync( + /// reranking_context, + /// documents, + /// options); + /// + /// // get the best resulting document from the query + /// results.RerankScores.First().Document; + /// // or the index of the document in the original list + /// results.RerankScores.First().Index; + /// // or the reranking score + /// results.RerankScores.First().Score; + /// + /// // get the latency information from the reranking operation + /// Dictonary tokenUseageInfo = results.TokenUseage; + /// ]]> + /// + /// + public abstract Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default); +#endif + /// /// Deletes all items in the Container with the specified value. /// Starts an asynchronous Cosmos DB background operation which deletes all items in the Container with the specified value. diff --git a/Microsoft.Azure.Cosmos/src/Resource/Container/ContainerInlineCore.cs b/Microsoft.Azure.Cosmos/src/Resource/Container/ContainerInlineCore.cs index 44a409eed1..cd38bada65 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/Container/ContainerInlineCore.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/Container/ContainerInlineCore.cs @@ -697,5 +697,16 @@ public override Task IsFeedRangePartOfAsync( y, cancellationToken: cancellationToken)); } + +#if PREVIEW + public override Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default) + { + return this.ClientContext.SemanticRerankAsync(rerankContext, documents, options, cancellationToken); + } +#endif } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/src/Resource/CosmosClientContext.cs b/Microsoft.Azure.Cosmos/src/Resource/CosmosClientContext.cs index e4c2592cc4..da5a613dab 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/CosmosClientContext.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/CosmosClientContext.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos { using System; + using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -132,6 +133,33 @@ internal abstract Task InitializeContainerUsingRntbdAsync( string containerLinkUri, CancellationToken cancellationToken); + /// + /// Rerank a list of documents using semantic reranking. + /// This method uses a semantic reranker to score and reorder the provided documents + /// based on their relevance to the given reranking context. + /// + /// The sematic reranking requests will not use the regular request flow and not use the default SDK retry policies. + /// + /// The context (ex: query string) to use for reranking the documents. + /// A list of documents to be reranked + /// (Optional) The options for the semantic reranking request. + /// (Optional) representing request cancellation. + /// The reranking results, typically including the reranked documents and their scores. + internal abstract Task SemanticRerankAsync( + string rerankContext, + IEnumerable documents, + IDictionary options = null, + CancellationToken cancellationToken = default); + + /// + /// Creates, or gets if already created, the inference service for this client + /// This will have a seperate http client that is used to make calls to the inference end point + /// + /// This method exists in the client context so the infernece service can be easily disposed when the client is disposed + /// + /// the inferenceService + internal abstract InferenceService GetOrCreateInferenceService(); + public abstract void Dispose(); } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj index efc6c67076..fdde935e9f 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj @@ -51,6 +51,7 @@ + diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/SemanticRerankingIntegrationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/SemanticRerankingIntegrationTests.cs new file mode 100644 index 0000000000..40f4a6ae39 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/SemanticRerankingIntegrationTests.cs @@ -0,0 +1,118 @@ +namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests +{ + using System; + using System.Collections.Generic; + using System.Text.Json; + using System.Text.Json.Serialization; + using System.Threading; + using System.Threading.Tasks; + using global::Azure.Core; + using global::Azure.Identity; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.MultiRegionSetupHelpers; + + [TestClass] + public class SemanticRerankingIntegrationTests + { + private string connectionString; + private CosmosClient client; + + private CosmosSystemTextJsonSerializer cosmosSystemTextJsonSerializer; + + [TestInitialize] + public void TestInitAsync() + { + this.connectionString = "https://inferencee2etest.documents.azure.com:443/"; + Environment.SetEnvironmentVariable("AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", "https://inferencee2etest.dbinference.azure.com"); + DefaultAzureCredentialOptions options = new DefaultAzureCredentialOptions + { + TenantId = "72f988bf-86f1-41af-91ab-2d7cd011db47", + ExcludeVisualStudioCredential = true + }; + + //Create a cosmos client using AAD authentication + TokenCredential tokenCredential = new DefaultAzureCredential(options); + + JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions() + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + this.cosmosSystemTextJsonSerializer = new MultiRegionSetupHelpers.CosmosSystemTextJsonSerializer(jsonSerializerOptions); + + if (string.IsNullOrEmpty(this.connectionString)) + { + Assert.Fail("Set environment variable COSMOSDB_MULTI_REGION to run the tests"); + } + this.client = new CosmosClient( + this.connectionString, + tokenCredential, + new CosmosClientOptions() + { + Serializer = this.cosmosSystemTextJsonSerializer, + }); + } + + [TestCleanup] + public void TestCleanup() + { + this.client?.Dispose(); + } + +#if PREVIEW + [TestMethod] + [TestCategory("Ignore")] + [Timeout(70000)] + public async Task SemanticRerankTest() + { + Database db = this.client.GetDatabase("virtualstore"); + Container container = db.GetContainer("sportinggoods"); + + string search_text = "integrated pull-up bar"; + + string queryString = $@" + SELECT TOP 15 c.id, c.Name, c.Brand, c.Description + FROM c + WHERE FullTextContains(c.Description, ""{search_text}"") + ORDER BY RANK FullTextScore(c.Description, ""{search_text}"") + "; + + string reranking_context = "most economical with multiple pulley adjustmnets and ideal for home gyms"; + + List documents = new List(); + FeedIterator resultSetIterator = container.GetItemQueryIterator( + new QueryDefinition(queryString), + requestOptions: new QueryRequestOptions() + { + MaxItemCount = 15, + }); + + while (resultSetIterator.HasMoreResults) + { + FeedResponse response = await resultSetIterator.ReadNextAsync(); + foreach (JsonElement item in response) + { + documents.Add(item.ToString()); + } + } + + Dictionary options = new Dictionary + { + { "return_documents", true }, + { "top_k", 10 }, + { "batch_size", 32 }, + { "sort", true } + }; + + SemanticRerankResult results = await container.SemanticRerankAsync( + reranking_context, + documents, + options); + + Assert.IsTrue(results.RerankScores.Count > 0); + Assert.AreEqual(4, results.RerankScores[0].Index); + Assert.IsNotNull(results.Latency); + Assert.IsNotNull(results.TokenUseage); + } +#endif + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json index 228784345a..67e92f2223 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json @@ -293,6 +293,11 @@ "Attributes": [], "MethodInfo": "Microsoft.Azure.Cosmos.ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilderWithAllVersionsAndDeletes[T](System.String, ChangeFeedHandler`1);IsAbstract:True;IsStatic:False;IsVirtual:True;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" }, + "System.Threading.Tasks.Task`1[Microsoft.Azure.Cosmos.SemanticRerankResult] SemanticRerankAsync(System.String, System.Collections.Generic.IEnumerable`1[System.String], System.Collections.Generic.IDictionary`2[System.String,System.Object], System.Threading.CancellationToken)": { + "Type": "Method", + "Attributes": [], + "MethodInfo": "System.Threading.Tasks.Task`1[Microsoft.Azure.Cosmos.SemanticRerankResult] SemanticRerankAsync(System.String, System.Collections.Generic.IEnumerable`1[System.String], System.Collections.Generic.IDictionary`2[System.String,System.Object], System.Threading.CancellationToken);IsAbstract:True;IsStatic:False;IsVirtual:True;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, "System.Threading.Tasks.Task`1[System.Boolean] IsFeedRangePartOfAsync(Microsoft.Azure.Cosmos.FeedRange, Microsoft.Azure.Cosmos.FeedRange, System.Threading.CancellationToken)": { "Type": "Method", "Attributes": [], @@ -1390,6 +1395,107 @@ }, "NestedTypes": {} }, + "Microsoft.Azure.Cosmos.RerankScore;System.Object;IsAbstract:False;IsSealed:False;IsInterface:False;IsEnum:False;IsClass:True;IsValueType:False;IsNested:False;IsGenericType:False;IsSerializable:False": { + "Subclasses": {}, + "Members": { + "Double get_Score()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "Double get_Score();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "Double Score": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "Double Score;CanRead:True;CanWrite:False;Double get_Score();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "Int32 get_Index()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "Int32 get_Index();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "Int32 Index": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "Int32 Index;CanRead:True;CanWrite:False;Int32 get_Index();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Object Document": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "System.Object Document;CanRead:True;CanWrite:False;System.Object get_Document();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Object get_Document()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "System.Object get_Document();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "Void .ctor(System.Object, Double, Int32)": { + "Type": "Constructor", + "Attributes": [], + "MethodInfo": "Void .ctor(System.Object, Double, Int32)" + } + }, + "NestedTypes": {} + }, + "Microsoft.Azure.Cosmos.SemanticRerankResult;System.Object;IsAbstract:False;IsSealed:False;IsInterface:False;IsEnum:False;IsClass:True;IsValueType:False;IsNested:False;IsGenericType:False;IsSerializable:False": { + "Subclasses": {}, + "Members": { + "System.Collections.Generic.Dictionary`2[System.String,System.Object] get_Latency()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "System.Collections.Generic.Dictionary`2[System.String,System.Object] get_Latency();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Collections.Generic.Dictionary`2[System.String,System.Object] get_TokenUseage()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "System.Collections.Generic.Dictionary`2[System.String,System.Object] get_TokenUseage();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Collections.Generic.Dictionary`2[System.String,System.Object] Latency": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "System.Collections.Generic.Dictionary`2[System.String,System.Object] Latency;CanRead:True;CanWrite:False;System.Collections.Generic.Dictionary`2[System.String,System.Object] get_Latency();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Collections.Generic.Dictionary`2[System.String,System.Object] TokenUseage": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "System.Collections.Generic.Dictionary`2[System.String,System.Object] TokenUseage;CanRead:True;CanWrite:False;System.Collections.Generic.Dictionary`2[System.String,System.Object] get_TokenUseage();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Collections.Generic.IReadOnlyList`1[Microsoft.Azure.Cosmos.RerankScore] get_RerankScores()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "System.Collections.Generic.IReadOnlyList`1[Microsoft.Azure.Cosmos.RerankScore] get_RerankScores();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Collections.Generic.IReadOnlyList`1[Microsoft.Azure.Cosmos.RerankScore] RerankScores": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "System.Collections.Generic.IReadOnlyList`1[Microsoft.Azure.Cosmos.RerankScore] RerankScores;CanRead:True;CanWrite:False;System.Collections.Generic.IReadOnlyList`1[Microsoft.Azure.Cosmos.RerankScore] get_RerankScores();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Net.Http.Headers.HttpResponseHeaders get_Headers()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": { + "Type": "Method", + "Attributes": [ + "CompilerGeneratedAttribute" + ], + "MethodInfo": "System.Net.Http.Headers.HttpResponseHeaders get_Headers();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + }, + "System.Net.Http.Headers.HttpResponseHeaders Headers": { + "Type": "Property", + "Attributes": [], + "MethodInfo": "System.Net.Http.Headers.HttpResponseHeaders Headers;CanRead:True;CanWrite:False;System.Net.Http.Headers.HttpResponseHeaders get_Headers();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;" + } + }, + "NestedTypes": {} + }, "Microsoft.Azure.Cosmos.VectorIndexPath;System.Object;IsAbstract:False;IsSealed:True;IsInterface:False;IsEnum:False;IsClass:True;IsValueType:False;IsNested:False;IsGenericType:False;IsSerializable:False": { "Subclasses": {}, "Members": {