diff --git a/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs b/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
index 80a9478dab..4c69b04ef4 100644
--- a/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
+++ b/Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
@@ -306,6 +306,24 @@ public int GatewayModeMaxConnectionLimit
///
public TimeSpan RequestTimeout { get; set; }
+ /// <summary>
+ /// Gets or sets the request timeout for inference service operations (e.g., semantic reranking).
+ /// The value specifies the time to wait for a response from the inference service before the request is cancelled.
+ /// This is a single-attempt timeout with no retries.
+ /// </summary>
+ /// <value>Default value is 5 seconds.</value>
+ /// <remarks>
+ /// This timeout is specific to inference service operations and is separate from the standard <see cref="RequestTimeout"/>.
+ /// If the request does not complete within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
+ /// No retries are attempted on timeout.
+ /// </remarks>
+#if PREVIEW
+ public
+#else
+ internal
+#endif
+ TimeSpan InferenceRequestTimeout { get; set; } = InferenceService.DefaultInferenceRequestTimeout;
+
///
/// The SDK does a background refresh based on the time interval set to refresh the token credentials.
/// This avoids latency issues because the old token is used until the new token is retrieved.
diff --git a/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs b/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
index e93229ae70..e33de132e0 100644
--- a/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
+++ b/Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
@@ -387,6 +387,26 @@ public CosmosClientBuilder WithRequestTimeout(TimeSpan requestTimeout)
return this;
}
+ /// <summary>
+ /// Sets the request timeout for inference service operations (e.g., semantic reranking).
+ /// This is a single-attempt timeout with no retries; if the request does not complete
+ /// within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
+ /// </summary>
+ /// <param name="inferenceRequestTimeout">A time to use as timeout for inference operations.</param>
+ /// <value>Default value is 5 seconds.</value>
+ /// <returns>The current <see cref="CosmosClientBuilder"/>.</returns>
+ /// <seealso cref="CosmosClientOptions.InferenceRequestTimeout"/>
+#if PREVIEW
+ public
+#else
+ internal
+#endif
+ CosmosClientBuilder WithInferenceRequestTimeout(TimeSpan inferenceRequestTimeout)
+ {
+ this.clientOptions.InferenceRequestTimeout = inferenceRequestTimeout;
+ return this;
+ }
+
///
/// Sets the connection mode to Direct. This is used by the client when connecting to the Azure Cosmos DB service.
///
diff --git a/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
index bf05c9698d..4436f536c5 100644
--- a/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
+++ b/Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
@@ -15,6 +15,7 @@ namespace Microsoft.Azure.Cosmos
using System.Threading;
using System.Threading.Tasks;
using global::Azure.Core;
+ using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;
@@ -32,9 +33,16 @@ internal class InferenceService : IDisposable
private const string InferenceTokenPrefix = "Bearer ";
private const int inferenceServiceDefaultMaxConnectionLimit = 50;
+ /// <summary>
+ /// Default per-request timeout for inference requests. Referenced by
+ /// <see cref="CosmosClientOptions.InferenceRequestTimeout"/>.
+ /// </summary>
+ internal static readonly TimeSpan DefaultInferenceRequestTimeout = TimeSpan.FromSeconds(5);
+
private readonly int inferenceServiceMaxConnectionLimit;
private readonly string inferenceServiceBaseUrl;
private readonly Uri inferenceEndpoint;
+ private readonly TimeSpan inferenceRequestTimeout;
private HttpClient httpClient;
private AuthorizationTokenProvider cosmosAuthorization;
@@ -59,6 +67,9 @@ public InferenceService(CosmosClient client)
"AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT",
inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit;
+ Debug.Assert(client.ClientOptions != null, "ClientOptions should not be null");
+ this.inferenceRequestTimeout = client.ClientOptions.InferenceRequestTimeout;
+
// Create and configure HttpClient for inference requests.
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit,
@@ -95,6 +106,7 @@ public InferenceService(CosmosClient client)
///
internal InferenceService(HttpMessageHandler messageHandler, Uri inferenceEndpoint, AuthorizationTokenProvider cosmosAuthorization)
{
+ this.inferenceRequestTimeout = InferenceService.DefaultInferenceRequestTimeout;
this.httpClient = new HttpClient(messageHandler);
this.CreateClientHelper(this.httpClient);
this.inferenceEndpoint = inferenceEndpoint;
@@ -115,6 +127,8 @@ public async Task SemanticRerankAsync(
IDictionary options = null,
CancellationToken cancellationToken = default)
{
+ DateTime startDateTimeUtc = DateTime.UtcNow;
+
// Prepare HTTP request for semantic reranking.
HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
INameValueCollection additionalHeaders = new RequestNameValueCollection();
@@ -139,8 +153,29 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
Encoding.UTF8,
RuntimeConstants.MediaTypes.Json);
- // Send the request and check for success.
- HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
+ // Enforce a single-attempt, no-retry timeout for the inference request.
+ // HttpClient.Timeout is intentionally left unchanged; this linked CTS is the authoritative
+ // per-request timeout for inference calls.
+ using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ linkedCts.CancelAfter(this.inferenceRequestTimeout);
+
+ HttpResponseMessage responseMessage;
+ try
+ {
+ responseMessage = await this.httpClient.SendAsync(message, linkedCts.Token);
+ }
+ catch (OperationCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested)
+ {
+ // Timeout triggered by the linked CTS (not the caller's cancellationToken).
+ string errorMessage = $"Inference Service Request Timeout. Start Time UTC:{startDateTimeUtc}; Total Duration:{(DateTime.UtcNow - startDateTimeUtc).TotalMilliseconds} Ms; Inference Request Timeout:{this.inferenceRequestTimeout.TotalMilliseconds} Ms; Activity id: {System.Diagnostics.Trace.CorrelationManager.ActivityId};";
+ throw CosmosExceptionFactory.CreateRequestTimeoutException(
+ message: errorMessage,
+ headers: new Headers()
+ {
+ ActivityId = System.Diagnostics.Trace.CorrelationManager.ActivityId.ToString()
+ },
+ innerException: operationCanceledException);
+ }
if (!responseMessage.IsSuccessStatusCode)
{
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
index 9e75de23de..9a78dbe18a 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json
@@ -392,11 +392,30 @@
"Attributes": [],
"MethodInfo": "System.Nullable`1[System.Int32] ThroughputBucket;CanRead:True;CanWrite:True;System.Nullable`1[System.Int32] get_ThroughputBucket();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_ThroughputBucket(System.Nullable`1[System.Int32]);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
+ "System.TimeSpan get_InferenceRequestTimeout()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
+ "Type": "Method",
+ "Attributes": [
+ "CompilerGeneratedAttribute"
+ ],
+ "MethodInfo": "System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+ },
+ "System.TimeSpan InferenceRequestTimeout": {
+ "Type": "Property",
+ "Attributes": [],
+ "MethodInfo": "System.TimeSpan InferenceRequestTimeout;CanRead:True;CanWrite:True;System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+ },
"Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean)": {
"Type": "Method",
"Attributes": [],
"MethodInfo": "Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
+ "Void set_InferenceRequestTimeout(System.TimeSpan)[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
+ "Type": "Method",
+ "Attributes": [
+ "CompilerGeneratedAttribute"
+ ],
+ "MethodInfo": "Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+ },
"Void set_ReadConsistencyStrategy(System.Nullable`1[Microsoft.Azure.Cosmos.ReadConsistencyStrategy])[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
"Type": "Method",
"Attributes": [
@@ -1232,6 +1251,11 @@
"Attributes": [],
"MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithEnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
+ "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan)": {
+ "Type": "Method",
+ "Attributes": [],
+ "MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
+ },
"Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithReadConsistencyStrategy(Microsoft.Azure.Cosmos.ReadConsistencyStrategy)": {
"Type": "Method",
"Attributes": [],
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
index 47ebf31e80..1949f65385 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs
@@ -94,6 +94,60 @@ public async Task SemanticRerankAsync_SuccessResponse_ReturnsResult()
Assert.AreEqual(0, result.RerankScores[0].Index);
}
+ [TestMethod]
+ public async Task SemanticRerankAsync_RequestExceedsInferenceTimeout_Throws408CosmosException()
+ {
+ // Arrange: the handler delays for 10 seconds, while the internal InferenceService ctor uses
+ // DefaultInferenceRequestTimeout (5 seconds), so the service's linked CTS should cancel first.
+ DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
+ delay: TimeSpan.FromSeconds(10),
+ statusCode: HttpStatusCode.OK,
+ responseContent: "{}");
+
+ Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
+
+ using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
+
+ CosmosException exception = await Assert.ThrowsExceptionAsync<CosmosException>(
+ () => service.SemanticRerankAsync(
+ rerankContext: "test query",
+ documents: new List<string> { "doc1", "doc2" }));
+
+ Assert.AreEqual(HttpStatusCode.RequestTimeout, exception.StatusCode);
+ Assert.IsTrue(
+ exception.Message.Contains("Inference Service Request Timeout"),
+ $"Expected timeout message. Actual: {exception.Message}");
+ }
+
+ [TestMethod]
+ public async Task SemanticRerankAsync_UserCancellation_PropagatesOperationCanceledException()
+ {
+ // Arrange: the handler delays long enough that the caller's 200 ms cancellation fires first.
+ DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
+ delay: TimeSpan.FromSeconds(10),
+ statusCode: HttpStatusCode.OK,
+ responseContent: "{}");
+
+ Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
+
+ using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
+ using CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(200));
+
+ try
+ {
+ await service.SemanticRerankAsync(
+ rerankContext: "test query",
+ documents: new List<string> { "doc1", "doc2" },
+ cancellationToken: cts.Token);
+ Assert.Fail("Expected OperationCanceledException to propagate when the caller cancels.");
+ }
+ catch (OperationCanceledException)
+ {
+ // Expected: user cancellation should surface as OperationCanceledException (or its
+ // TaskCanceledException subclass), not be swallowed into a timeout CosmosException.
+ }
+ }
+
private static Mock CreateMockAuthorizationTokenProvider()
{
Mock mockAuth = new Mock();
@@ -132,5 +186,35 @@ protected override Task SendAsync(
return Task.FromResult(response);
}
}
+
+ /// <summary>
+ /// HttpMessageHandler that delays for a configurable duration before responding.
+ /// Used to exercise the per-request inference timeout.
+ /// </summary>
+ private class DelayedMessageHandler : HttpMessageHandler
+ {
+ private readonly TimeSpan delay;
+ private readonly HttpStatusCode statusCode;
+ private readonly string responseContent;
+
+ public DelayedMessageHandler(TimeSpan delay, HttpStatusCode statusCode, string responseContent)
+ {
+ this.delay = delay;
+ this.statusCode = statusCode;
+ this.responseContent = responseContent;
+ }
+
+ protected override async Task<HttpResponseMessage> SendAsync(
+ HttpRequestMessage request,
+ CancellationToken cancellationToken)
+ {
+ await Task.Delay(this.delay, cancellationToken);
+
+ return new HttpResponseMessage(this.statusCode)
+ {
+ Content = new StringContent(this.responseContent, Encoding.UTF8, "application/json")
+ };
+ }
+ }
}
}