Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,24 @@ public int GatewayModeMaxConnectionLimit
/// <seealso cref="CosmosClientBuilder.WithRequestTimeout(TimeSpan)"/>
public TimeSpan RequestTimeout { get; set; }

/// <summary>
/// Gets or sets the request timeout for inference service operations (e.g., semantic reranking).
/// The number specifies the time to wait for a response from the inference service before the request is cancelled.
/// This is a single-attempt timeout with no retries.
/// </summary>
/// <value>Default value is 5 seconds.</value>
/// <remarks>
/// This timeout is specific to inference service operations and is separate from the standard <see cref="RequestTimeout"/>.
/// If the request does not complete within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
/// No retries are attempted on timeout.
/// </remarks>
#if PREVIEW
public
#else
internal
#endif
TimeSpan InferenceRequestTimeout { get; set; } = InferenceService.DefaultInferenceRequestTimeout;

/// <summary>
/// The SDK does a background refresh based on the time interval set to refresh the token credentials.
/// This avoids latency issues because the old token is used until the new token is retrieved.
Expand Down
20 changes: 20 additions & 0 deletions Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,26 @@ public CosmosClientBuilder WithRequestTimeout(TimeSpan requestTimeout)
return this;
}

/// <summary>
/// Sets the request timeout for inference service operations (e.g., semantic reranking).
/// This is a single-attempt timeout with no retries; if the request does not complete
/// within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
/// </summary>
/// <param name="inferenceRequestTimeout">A time to use as timeout for inference operations.</param>
/// <value>Default value is 5 seconds.</value>
/// <returns>The current <see cref="CosmosClientBuilder"/>.</returns>
/// <seealso cref="CosmosClientOptions.InferenceRequestTimeout"/>
#if PREVIEW
public
#else
internal
#endif
CosmosClientBuilder WithInferenceRequestTimeout(TimeSpan inferenceRequestTimeout)
{
this.clientOptions.InferenceRequestTimeout = inferenceRequestTimeout;
return this;
}

/// <summary>
/// Sets the connection mode to Direct. This is used by the client when connecting to the Azure Cosmos DB service.
/// </summary>
Expand Down
39 changes: 37 additions & 2 deletions Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace Microsoft.Azure.Cosmos
using System.Threading;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;

Expand All @@ -32,9 +33,16 @@ internal class InferenceService : IDisposable
private const string InferenceTokenPrefix = "Bearer ";
private const int inferenceServiceDefaultMaxConnectionLimit = 50;

/// <summary>
/// Default per-request timeout for inference requests. Referenced by
/// <see cref="CosmosClientOptions.InferenceRequestTimeout"/>.
/// </summary>
internal static readonly TimeSpan DefaultInferenceRequestTimeout = TimeSpan.FromSeconds(5);

private readonly int inferenceServiceMaxConnectionLimit;
private readonly string inferenceServiceBaseUrl;
private readonly Uri inferenceEndpoint;
private readonly TimeSpan inferenceRequestTimeout;

private HttpClient httpClient;
private AuthorizationTokenProvider cosmosAuthorization;
Expand All @@ -59,6 +67,9 @@ public InferenceService(CosmosClient client)
"AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT",
inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit;

Debug.Assert(client.ClientOptions != null, "ClientOptions should not be null");
this.inferenceRequestTimeout = client.ClientOptions.InferenceRequestTimeout;

// Create and configure HttpClient for inference requests.
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit,
Expand Down Expand Up @@ -95,6 +106,7 @@ public InferenceService(CosmosClient client)
/// </summary>
internal InferenceService(HttpMessageHandler messageHandler, Uri inferenceEndpoint, AuthorizationTokenProvider cosmosAuthorization)
{
this.inferenceRequestTimeout = InferenceService.DefaultInferenceRequestTimeout;
this.httpClient = new HttpClient(messageHandler);
this.CreateClientHelper(this.httpClient);
this.inferenceEndpoint = inferenceEndpoint;
Expand All @@ -115,6 +127,8 @@ public async Task<SemanticRerankResult> SemanticRerankAsync(
IDictionary<string, object> options = null,
CancellationToken cancellationToken = default)
{
DateTime startDateTimeUtc = DateTime.UtcNow;

// Prepare HTTP request for semantic reranking.
HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
INameValueCollection additionalHeaders = new RequestNameValueCollection();
Expand All @@ -139,8 +153,29 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
Encoding.UTF8,
RuntimeConstants.MediaTypes.Json);

// Send the request and check for success.
HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
// Enforce a single-attempt, no-retry timeout for the inference request.
// HttpClient.Timeout is intentionally left unchanged; this linked CTS is the authoritative
// per-request timeout for inference calls.
using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
linkedCts.CancelAfter(this.inferenceRequestTimeout);

HttpResponseMessage responseMessage;
try
{
responseMessage = await this.httpClient.SendAsync(message, linkedCts.Token);
}
catch (OperationCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
// Timeout triggered by the linked CTS (not the caller's cancellationToken).
string errorMessage = $"Inference Service Request Timeout. Start Time UTC:{startDateTimeUtc}; Total Duration:{(DateTime.UtcNow - startDateTimeUtc).TotalMilliseconds} Ms; Inference Request Timeout:{this.inferenceRequestTimeout.TotalMilliseconds} Ms; Activity id: {System.Diagnostics.Trace.CorrelationManager.ActivityId};";
throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: errorMessage,
headers: new Headers()
{
ActivityId = System.Diagnostics.Trace.CorrelationManager.ActivityId.ToString()
},
innerException: operationCanceledException);
}

if (!responseMessage.IsSuccessStatusCode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,30 @@
"Attributes": [],
"MethodInfo": "System.Nullable`1[System.Int32] ThroughputBucket;CanRead:True;CanWrite:True;System.Nullable`1[System.Int32] get_ThroughputBucket();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_ThroughputBucket(System.Nullable`1[System.Int32]);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"System.TimeSpan get_InferenceRequestTimeout()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
"Type": "Method",
"Attributes": [
"CompilerGeneratedAttribute"
],
"MethodInfo": "System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"System.TimeSpan InferenceRequestTimeout": {
"Type": "Property",
"Attributes": [],
"MethodInfo": "System.TimeSpan InferenceRequestTimeout;CanRead:True;CanWrite:True;System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean)": {
"Type": "Method",
"Attributes": [],
"MethodInfo": "Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Void set_InferenceRequestTimeout(System.TimeSpan)[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
"Type": "Method",
"Attributes": [
"CompilerGeneratedAttribute"
],
"MethodInfo": "Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Void set_ReadConsistencyStrategy(System.Nullable`1[Microsoft.Azure.Cosmos.ReadConsistencyStrategy])[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
"Type": "Method",
"Attributes": [
Expand Down Expand Up @@ -1232,6 +1251,11 @@
"Attributes": [],
"MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithEnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan)": {
"Type": "Method",
"Attributes": [],
"MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithReadConsistencyStrategy(Microsoft.Azure.Cosmos.ReadConsistencyStrategy)": {
"Type": "Method",
"Attributes": [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,60 @@ public async Task SemanticRerankAsync_SuccessResponse_ReturnsResult()
Assert.AreEqual(0, result.RerankScores[0].Index);
}

[TestMethod]
public async Task SemanticRerankAsync_RequestExceedsInferenceTimeout_Throws408CosmosException()
{
// Handler delays for 10 seconds; the internal InferenceService ctor uses the
// DefaultInferenceRequestTimeout (5 seconds), so the linked CTS should cancel first.
DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
delay: TimeSpan.FromSeconds(10),
statusCode: HttpStatusCode.OK,
responseContent: "{}");

Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();

using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);

CosmosException exception = await Assert.ThrowsExceptionAsync<CosmosException>(
() => service.SemanticRerankAsync(
rerankContext: "test query",
documents: new List<string> { "doc1", "doc2" }));

Assert.AreEqual(HttpStatusCode.RequestTimeout, exception.StatusCode);
Assert.IsTrue(
exception.Message.Contains("Inference Service Request Timeout"),
$"Expected timeout message. Actual: {exception.Message}");
}

[TestMethod]
public async Task SemanticRerankAsync_UserCancellation_PropagatesOperationCanceledException()
{
// Handler delays long enough that user cancellation should fire first.
DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
delay: TimeSpan.FromSeconds(10),
statusCode: HttpStatusCode.OK,
responseContent: "{}");

Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();

using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
using CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(200));

try
{
await service.SemanticRerankAsync(
rerankContext: "test query",
documents: new List<string> { "doc1", "doc2" },
cancellationToken: cts.Token);
Assert.Fail("Expected OperationCanceledException to propagate when the caller cancels.");
}
catch (OperationCanceledException)
{
// Expected: user cancellation should surface as OperationCanceledException (or its
// TaskCanceledException subclass), not be swallowed into a timeout CosmosException.
}
}

private static Mock<AuthorizationTokenProvider> CreateMockAuthorizationTokenProvider()
{
Mock<AuthorizationTokenProvider> mockAuth = new Mock<AuthorizationTokenProvider>();
Expand Down Expand Up @@ -132,5 +186,35 @@ protected override Task<HttpResponseMessage> SendAsync(
return Task.FromResult(response);
}
}

/// <summary>
/// HttpMessageHandler that delays for a configurable duration before responding.
/// Used to exercise the per-request inference timeout.
/// </summary>
private class DelayedMessageHandler : HttpMessageHandler
{
private readonly TimeSpan delay;
private readonly HttpStatusCode statusCode;
private readonly string responseContent;

public DelayedMessageHandler(TimeSpan delay, HttpStatusCode statusCode, string responseContent)
{
this.delay = delay;
this.statusCode = statusCode;
this.responseContent = responseContent;
}

protected override async Task<HttpResponseMessage> SendAsync(
HttpRequestMessage request,
CancellationToken cancellationToken)
{
await Task.Delay(this.delay, cancellationToken);

return new HttpResponseMessage(this.statusCode)
{
Content = new StringContent(this.responseContent, Encoding.UTF8, "application/json")
};
}
}
}
}
Loading