Reviewed patch: configurable per-request timeout for inference (semantic rerank) calls.

Review changes folded into this revision of the patch:
  * CosmosClientOptions.InferenceRequestTimeout now validates in its setter. Zero, negative
    (other than Timeout.InfiniteTimeSpan) and > int.MaxValue ms values are rejected up front
    instead of failing on every rerank call inside CancellationTokenSource.CancelAfter (or,
    for zero, producing an immediate 408). The builder method inherits the check.
  * The property is therefore no longer an auto-property, so the preview API contract entries
    for its accessors no longer carry CompilerGeneratedAttribute.
  * The test-only InferenceService constructor takes an optional inferenceRequestTimeout so
    the 408 test finishes in ~200 ms instead of waiting out the 5 second default.
Notes for whoever applies this:
  * XML doc tags and generic type arguments were reconstructed from a copy of the patch in
    which angle-bracket markup had been lost; confirm context lines against the working tree.
  * "index" lines are carried over from the original patch; their post-image hashes are stale
    for the files changed above.

diff --git a/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs b/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
index 80a9478dab..4c69b04ef4 100644
--- a/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
+++ b/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
@@ -306,6 +306,48 @@ public int GatewayModeMaxConnectionLimit
         /// <seealso cref="CosmosClientBuilder.WithRequestTimeout(TimeSpan)"/>
         public TimeSpan RequestTimeout { get; set; }
 
+        /// <summary>
+        /// Gets or sets the request timeout for inference service operations (e.g., semantic reranking).
+        /// The number specifies the time to wait for a response from the inference service before the request is cancelled.
+        /// This is a single-attempt timeout with no retries.
+        /// </summary>
+        /// <value>Default value is 5 seconds.</value>
+        /// <remarks>
+        /// This timeout is specific to inference service operations and is separate from the standard <see cref="RequestTimeout"/>.
+        /// If the request does not complete within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
+        /// No retries are attempted on timeout.
+        /// </remarks>
+        /// <exception cref="ArgumentOutOfRangeException">
+        /// The value is zero, negative (other than <see cref="System.Threading.Timeout.InfiniteTimeSpan"/>), or longer than <see cref="int.MaxValue"/> milliseconds.
+        /// </exception>
+#if PREVIEW
+        public
+#else
+        internal
+#endif
+        TimeSpan InferenceRequestTimeout
+        {
+            get => this.inferenceRequestTimeout;
+            set
+            {
+                // Fail fast at configuration time. Without this check a bad value only surfaces on
+                // every rerank call: CancellationTokenSource.CancelAfter throws for negative or
+                // oversized delays, and zero cancels immediately (a guaranteed 408).
+                bool isInfinite = value == System.Threading.Timeout.InfiniteTimeSpan;
+                if ((value <= TimeSpan.Zero && !isInfinite) || value.TotalMilliseconds > int.MaxValue)
+                {
+                    throw new ArgumentOutOfRangeException(
+                        nameof(value),
+                        value,
+                        $"{nameof(CosmosClientOptions.InferenceRequestTimeout)} must be a positive duration no longer than {int.MaxValue} milliseconds, or Timeout.InfiniteTimeSpan.");
+                }
+
+                this.inferenceRequestTimeout = value;
+            }
+        }
+
+        private TimeSpan inferenceRequestTimeout = InferenceService.DefaultInferenceRequestTimeout;
+
         /// <summary>
         /// The SDK does a background refresh based on the time interval set to refresh the token credentials.
         /// This avoids latency issues because the old token is used until the new token is retrieved.
diff --git a/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs b/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
index e93229ae70..e33de132e0 100644
--- a/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
+++ b/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
@@ -387,6 +387,27 @@ public CosmosClientBuilder WithRequestTimeout(TimeSpan requestTimeout)
             return this;
         }
 
+        /// <summary>
+        /// Sets the request timeout for inference service operations (e.g., semantic reranking).
+        /// This is a single-attempt timeout with no retries; if the request does not complete
+        /// within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
+        /// </summary>
+        /// <param name="inferenceRequestTimeout">A time to use as timeout for inference operations.</param>
+        /// <value>Default value is 5 seconds.</value>
+        /// <returns>The current <see cref="CosmosClientBuilder"/>.</returns>
+        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="inferenceRequestTimeout"/> is rejected by <see cref="CosmosClientOptions.InferenceRequestTimeout"/>.</exception>
+        /// <seealso cref="CosmosClientOptions.InferenceRequestTimeout"/>
+#if PREVIEW
+        public
+#else
+        internal
+#endif
+        CosmosClientBuilder WithInferenceRequestTimeout(TimeSpan inferenceRequestTimeout)
+        {
+            this.clientOptions.InferenceRequestTimeout = inferenceRequestTimeout;
+            return this;
+        }
+
         /// <summary>
         /// Sets the connection mode to Direct. This is used by the client when connecting to the Azure Cosmos DB service.
         /// </summary>
diff --git a/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
index bf05c9698d..4436f536c5 100644
--- a/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
+++ b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
@@ -15,6 +15,7 @@ namespace Microsoft.Azure.Cosmos
     using System.Threading;
     using System.Threading.Tasks;
     using global::Azure.Core;
+    using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
     using Microsoft.Azure.Documents;
     using Microsoft.Azure.Documents.Collections;
 
@@ -32,9 +33,16 @@ internal class InferenceService : IDisposable
         private const string InferenceTokenPrefix = "Bearer ";
         private const int inferenceServiceDefaultMaxConnectionLimit = 50;
 
+        /// <summary>
+        /// Default per-request timeout for inference requests. Referenced by
+        /// <see cref="CosmosClientOptions.InferenceRequestTimeout"/>.
+        /// </summary>
+        internal static readonly TimeSpan DefaultInferenceRequestTimeout = TimeSpan.FromSeconds(5);
+
         private readonly int inferenceServiceMaxConnectionLimit;
         private readonly string inferenceServiceBaseUrl;
         private readonly Uri inferenceEndpoint;
+        private readonly TimeSpan inferenceRequestTimeout;
 
         private HttpClient httpClient;
         private AuthorizationTokenProvider cosmosAuthorization;
@@ -59,6 +67,9 @@ public InferenceService(CosmosClient client)
                     "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT",
                     inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit;
 
+            Debug.Assert(client.ClientOptions != null, "ClientOptions should not be null");
+            this.inferenceRequestTimeout = client.ClientOptions.InferenceRequestTimeout;
+
             // Create and configure HttpClient for inference requests.
             HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
                 gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit,
@@ -95,6 +106,12 @@ public InferenceService(CosmosClient client)
         /// </summary>
-        internal InferenceService(HttpMessageHandler messageHandler, Uri inferenceEndpoint, AuthorizationTokenProvider cosmosAuthorization)
+        internal InferenceService(
+            HttpMessageHandler messageHandler,
+            Uri inferenceEndpoint,
+            AuthorizationTokenProvider cosmosAuthorization,
+            TimeSpan? inferenceRequestTimeout = null)
         {
+            // Tests may pass a short timeout so timeout paths run in milliseconds; null keeps the default.
+            this.inferenceRequestTimeout = inferenceRequestTimeout ?? InferenceService.DefaultInferenceRequestTimeout;
             this.httpClient = new HttpClient(messageHandler);
             this.CreateClientHelper(this.httpClient);
             this.inferenceEndpoint = inferenceEndpoint;
@@ -115,6 +132,8 @@ public async Task<SemanticRerankResult> SemanticRerankAsync(
             IDictionary<string, object> options = null,
             CancellationToken cancellationToken = default)
         {
+            DateTime startDateTimeUtc = DateTime.UtcNow;
+
             // Prepare HTTP request for semantic reranking.
             HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
             INameValueCollection additionalHeaders = new RequestNameValueCollection();
@@ -139,8 +158,29 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
                 Encoding.UTF8,
                 RuntimeConstants.MediaTypes.Json);
 
-            // Send the request and check for success.
-            HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
+            // Enforce a single-attempt, no-retry timeout for the inference request.
+            // HttpClient.Timeout is intentionally left unchanged; this linked CTS is the authoritative
+            // per-request timeout for inference calls.
+            using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+            linkedCts.CancelAfter(this.inferenceRequestTimeout);
+
+            HttpResponseMessage responseMessage;
+            try
+            {
+                responseMessage = await this.httpClient.SendAsync(message, linkedCts.Token);
+            }
+            catch (OperationCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested)
+            {
+                // Timeout triggered by the linked CTS (not the caller's cancellationToken).
+                string errorMessage = $"Inference Service Request Timeout. Start Time UTC:{startDateTimeUtc}; Total Duration:{(DateTime.UtcNow - startDateTimeUtc).TotalMilliseconds} Ms; Inference Request Timeout:{this.inferenceRequestTimeout.TotalMilliseconds} Ms; Activity id: {System.Diagnostics.Trace.CorrelationManager.ActivityId};";
+                throw CosmosExceptionFactory.CreateRequestTimeoutException(
+                    message: errorMessage,
+                    headers: new Headers()
+                    {
+                        ActivityId = System.Diagnostics.Trace.CorrelationManager.ActivityId.ToString()
+                    },
+                    innerException: operationCanceledException);
+            }
 
             if (!responseMessage.IsSuccessStatusCode)
             {
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
index 9e75de23de..9a78dbe18a 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
@@ -392,11 +392,26 @@
           "Attributes": [],
           "MethodInfo": "System.Nullable`1[System.Int32] ThroughputBucket;CanRead:True;CanWrite:True;System.Nullable`1[System.Int32] get_ThroughputBucket();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_ThroughputBucket(System.Nullable`1[System.Int32]);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
         },
+        "System.TimeSpan get_InferenceRequestTimeout()": {
+          "Type": "Method",
+          "Attributes": [],
+          "MethodInfo": "System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+        },
+        "System.TimeSpan InferenceRequestTimeout": {
+          "Type": "Property",
+          "Attributes": [],
+          "MethodInfo": "System.TimeSpan InferenceRequestTimeout;CanRead:True;CanWrite:True;System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+        },
         "Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean)": {
           "Type": "Method",
           "Attributes": [],
           "MethodInfo": "Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
         },
+        "Void set_InferenceRequestTimeout(System.TimeSpan)": {
+          "Type": "Method",
+          "Attributes": [],
+          "MethodInfo": "Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+        },
         "Void set_ReadConsistencyStrategy(System.Nullable`1[Microsoft.Azure.Cosmos.ReadConsistencyStrategy])[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
           "Type": "Method",
           "Attributes": [
@@ -1232,6 +1247,11 @@
           "Attributes": [],
           "MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithEnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
         },
+        "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan)": {
+          "Type": "Method",
+          "Attributes": [],
+          "MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+        },
         "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithReadConsistencyStrategy(Microsoft.Azure.Cosmos.ReadConsistencyStrategy)": {
           "Type": "Method",
           "Attributes": [],
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
index 47ebf31e80..1949f65385 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
@@ -94,6 +94,64 @@ public async Task SemanticRerankAsync_SuccessResponse_ReturnsResult()
             Assert.AreEqual(0, result.RerankScores[0].Index);
         }
 
+        [TestMethod]
+        public async Task SemanticRerankAsync_RequestExceedsInferenceTimeout_Throws408CosmosException()
+        {
+            // Handler delays for 10 seconds while the service is given a 200 ms inference timeout,
+            // so the linked CTS cancels first and the test does not wait out the 5 second default.
+            DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
+                delay: TimeSpan.FromSeconds(10),
+                statusCode: HttpStatusCode.OK,
+                responseContent: "{}");
+
+            Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
+
+            using InferenceService service = new InferenceService(
+                delayedHandler,
+                TestEndpoint,
+                mockAuth.Object,
+                inferenceRequestTimeout: TimeSpan.FromMilliseconds(200));
+
+            CosmosException exception = await Assert.ThrowsExceptionAsync<CosmosException>(
+                () => service.SemanticRerankAsync(
+                    rerankContext: "test query",
+                    documents: new List<string> { "doc1", "doc2" }));
+
+            Assert.AreEqual(HttpStatusCode.RequestTimeout, exception.StatusCode);
+            Assert.IsTrue(
+                exception.Message.Contains("Inference Service Request Timeout"),
+                $"Expected timeout message. Actual: {exception.Message}");
+        }
+
+        [TestMethod]
+        public async Task SemanticRerankAsync_UserCancellation_PropagatesOperationCanceledException()
+        {
+            // Handler delays long enough that user cancellation should fire first.
+            DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
+                delay: TimeSpan.FromSeconds(10),
+                statusCode: HttpStatusCode.OK,
+                responseContent: "{}");
+
+            Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
+
+            using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
+            using CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(200));
+
+            try
+            {
+                await service.SemanticRerankAsync(
+                    rerankContext: "test query",
+                    documents: new List<string> { "doc1", "doc2" },
+                    cancellationToken: cts.Token);
+                Assert.Fail("Expected OperationCanceledException to propagate when the caller cancels.");
+            }
+            catch (OperationCanceledException)
+            {
+                // Expected: user cancellation should surface as OperationCanceledException (or its
+                // TaskCanceledException subclass), not be swallowed into a timeout CosmosException.
+            }
+        }
+
         private static Mock<AuthorizationTokenProvider> CreateMockAuthorizationTokenProvider()
         {
             Mock<AuthorizationTokenProvider> mockAuth = new Mock<AuthorizationTokenProvider>();
@@ -132,5 +190,35 @@ protected override Task<HttpResponseMessage> SendAsync(
                 return Task.FromResult(response);
             }
         }
+
+        /// <summary>
+        /// HttpMessageHandler that delays for a configurable duration before responding.
+        /// Used to exercise the per-request inference timeout.
+        /// </summary>
+        private class DelayedMessageHandler : HttpMessageHandler
+        {
+            private readonly TimeSpan delay;
+            private readonly HttpStatusCode statusCode;
+            private readonly string responseContent;
+
+            public DelayedMessageHandler(TimeSpan delay, HttpStatusCode statusCode, string responseContent)
+            {
+                this.delay = delay;
+                this.statusCode = statusCode;
+                this.responseContent = responseContent;
+            }
+
+            protected override async Task<HttpResponseMessage> SendAsync(
+                HttpRequestMessage request,
+                CancellationToken cancellationToken)
+            {
+                await Task.Delay(this.delay, cancellationToken);
+
+                return new HttpResponseMessage(this.statusCode)
+                {
+                    Content = new StringContent(this.responseContent, Encoding.UTF8, "application/json")
+                };
+            }
+        }
     }
 }