diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.cs b/src/Grpc.Net.Client/Internal/GrpcCall.cs index 15aa25325..36f126a43 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.cs @@ -61,14 +61,16 @@ internal sealed partial class GrpcCall : GrpcCall, IGrpcCal public HttpContentClientStreamWriter? ClientStreamWriter { get; private set; } public HttpContentClientStreamReader? ClientStreamReader { get; private set; } - public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int attemptCount) + public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int attemptCount, bool forceAsyncHttpResponse) : base(options, channel) { // Validate deadline before creating any objects that require cleanup ValidateDeadline(options.Deadline); _callCts = new CancellationTokenSource(); - _httpResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // Retries and hedging can run multiple calls at the same time and use locking for thread-safety. + // Running HTTP response continuation asynchronously is required for locking to work correctly. + _httpResponseTcs = new TaskCompletionSource(forceAsyncHttpResponse ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); // Run the callTcs continuation immediately to keep the same context. Required for Activity. _callTcs = new TaskCompletionSource(); Method = method; @@ -142,7 +144,10 @@ public void StartDuplexStreaming() internal void StartUnaryCore(HttpContent content) { - _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // Not created with RunContinuationsAsynchronously to avoid unnecessary dispatch to the thread pool. + // The TCS is set from RunCall but it is the last operation before the method exits so there shouldn't + // be an impact from running the response continutation synchronously. + _responseTcs = new TaskCompletionSource(); var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); @@ -161,7 +166,10 @@ internal void StartServerStreamingCore(HttpContent content) internal void StartClientStreamingCore(HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { - _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // Not created with RunContinuationsAsynchronously to avoid unnecessary dispatch to the thread pool. + // The TCS is set from RunCall but it is the last operation before the method exits so there shouldn't + // be an impact from running the response continutation synchronously. + _responseTcs = new TaskCompletionSource(); var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); @@ -431,9 +439,6 @@ private void CancelCall(Status status) // Cancellation will also cause reader/writer to throw if used afterwards. _callCts.Cancel(); - // Ensure any logic that is waiting on the HttpResponse is unstuck. - _httpResponseTcs.TrySetCanceled(); - // Cancellation token won't send RST_STREAM if HttpClient.SendAsync is complete. // Dispose HttpResponseMessage to send RST_STREAM to server for in-progress calls. HttpResponse?.Dispose(); @@ -652,6 +657,9 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout) // Verify that FinishCall is called in every code path of this method. // Should create an "Unassigned variable" compiler error if not set. Debug.Assert(finished); + // Should be completed before exiting. + Debug.Assert(_httpResponseTcs.Task.IsCompleted); + Debug.Assert(_responseTcs == null || _responseTcs.Task.IsCompleted); } } diff --git a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs index 4783ef4fc..16db39afb 100644 --- a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs +++ b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs @@ -177,7 +177,7 @@ private static IGrpcCall CreateRootGrpcCall(channel, method, options, attempt: 1, callWrapper: null); + return CreateGrpcCall(channel, method, options, attempt: 1, forceAsyncHttpResponse: false, callWrapper: null); } } @@ -210,6 +210,7 @@ public static GrpcCall CreateGrpcCall( Method method, CallOptions options, int attempt, + bool forceAsyncHttpResponse, object? callWrapper) where TRequest : class where TResponse : class @@ -217,7 +218,7 @@ public static GrpcCall CreateGrpcCall( ObjectDisposedThrowHelper.ThrowIf(channel.Disposed, typeof(GrpcChannel)); var methodInfo = channel.GetCachedGrpcMethodInfo(method); - var call = new GrpcCall(method, methodInfo, options, channel, attempt); + var call = new GrpcCall(method, methodInfo, options, channel, attempt, forceAsyncHttpResponse); call.CallWrapper = callWrapper; return call; diff --git a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs index c53080fee..4eb134548 100644 --- a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs @@ -75,7 +75,7 @@ private async Task StartCall(Action> startCallFunc OnStartingAttempt(); - call = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, AttemptCount, CallWrapper); + call = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, AttemptCount, forceAsyncHttpResponse: true, CallWrapper); _activeCalls.Add(call); startCallFunc(call); diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs index dd645e3f0..f007a917b 100644 --- a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs @@ -109,7 +109,7 @@ private async Task StartRetry(Action> startCallFun // Start new call. OnStartingAttempt(); - currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, AttemptCount, CallWrapper); + currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, AttemptCount, forceAsyncHttpResponse: true, CallWrapper); startCallFunc(currentCall); SetNewActiveCallUnsynchronized(currentCall); diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs index 7f54895f1..53cf2da2d 100644 --- a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs @@ -37,6 +37,7 @@ internal abstract partial class RetryCallBase : IGrpcCall? _responseTask; private Task? _responseHeadersTask; private TRequest? _request; + private bool _commitStarted; // Internal for unit testing. internal CancellationTokenRegistration? _ctsRegistration; @@ -369,8 +370,11 @@ protected void CommitCall(IGrpcCall call, CommitReason comm { lock (Lock) { - if (!CommitedCallTask.IsCompletedSuccessfully()) + if (!_commitStarted) { + // Specify that call is commiting. This is to prevent any chance of re-entrancy from logic run in OnCommitCall. + _commitStarted = true; + // The buffer size is verified in unit tests after calls are completed. // Clear the buffer before commiting call. ClearRetryBuffer(); diff --git a/test/Grpc.Net.Client.Tests/CancellationTests.cs b/test/Grpc.Net.Client.Tests/CancellationTests.cs index a1bfe1009..6e4a7f006 100644 --- a/test/Grpc.Net.Client.Tests/CancellationTests.cs +++ b/test/Grpc.Net.Client.Tests/CancellationTests.cs @@ -113,7 +113,7 @@ public async Task AsyncClientStreamingCall_CancellationDuringSend_ThrowOperation cts.Cancel(); - var ex = await ExceptionAssert.ThrowsAsync(() => responseHeadersTask).DefaultTimeout(); + var ex = await ExceptionAssert.ThrowsAsync(() => responseHeadersTask).DefaultTimeout(); Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); Assert.AreEqual("Call canceled by the client.", call.GetStatus().Detail); } diff --git a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs index 4df9e5d56..4465738de 100644 --- a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs +++ b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs @@ -231,7 +231,8 @@ private static GrpcCall CreateGrpcCall(GrpcChannel cha new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri, methodConfig: null), new CallOptions(), channel, - attemptCount: 0); + attemptCount: 0, + forceAsyncHttpResponse: false); } private static GrpcChannel CreateChannel(HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool? throwOperationCanceledOnCancellation = null) diff --git a/test/Grpc.Net.Client.Tests/ResponseHeadersAsyncTests.cs b/test/Grpc.Net.Client.Tests/ResponseHeadersAsyncTests.cs index 249c39827..33e6879a7 100644 --- a/test/Grpc.Net.Client.Tests/ResponseHeadersAsyncTests.cs +++ b/test/Grpc.Net.Client.Tests/ResponseHeadersAsyncTests.cs @@ -112,7 +112,10 @@ public async Task AsyncUnaryCall_AuthInterceptorDispose_ResponseHeadersError() var credentialsSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); var credentials = CallCredentials.FromInterceptor(async (context, metadata) => { - await credentialsSyncPoint.WaitToContinue(); + var tcs = new TaskCompletionSource(); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), tcs); + + await Task.WhenAny(credentialsSyncPoint.WaitToContinue(), tcs.Task); metadata.Add("Authorization", $"Bearer TEST"); }); diff --git a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs index 105a18c89..3508dc1f0 100644 --- a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs +++ b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs @@ -165,7 +165,10 @@ public async Task AsyncUnaryCall_AuthInteceptorDispose_Error() var credentialsSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); var credentials = CallCredentials.FromInterceptor(async (context, metadata) => { - await credentialsSyncPoint.WaitToContinue(); + var tcs = new TaskCompletionSource(); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), tcs); + + await Task.WhenAny(credentialsSyncPoint.WaitToContinue(), tcs.Task); metadata.Add("Authorization", $"Bearer TEST"); }); var invoker = HttpClientCallInvokerFactory.Create(httpClient, loggerFactory: provider.GetRequiredService(), serviceConfig: serviceConfig, configure: options => options.Credentials = ChannelCredentials.Create(new SslCredentials(), credentials));