diff --git a/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs b/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs index a22738192f8b..61b215054ec0 100644 --- a/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs +++ b/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs @@ -25,6 +25,11 @@ public static T EnsureCompleted(this Task task) { #if DEBUG VerifyTaskCompleted(task.IsCompleted); +#else + if (HasSynchronizationContext()) + { + throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed."); + } #endif #pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. return task.GetAwaiter().GetResult(); @@ -35,6 +40,11 @@ public static void EnsureCompleted(this Task task) { #if DEBUG VerifyTaskCompleted(task.IsCompleted); +#else + if (HasSynchronizationContext()) + { + throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed."); + } #endif #pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. task.GetAwaiter().GetResult(); @@ -43,9 +53,12 @@ public static void EnsureCompleted(this Task task) public static T EnsureCompleted(this ValueTask task) { -#if DEBUG - VerifyTaskCompleted(task.IsCompleted); -#endif + if (!task.IsCompleted) + { +#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available. + return EnsureCompleted(task.AsTask()); +#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available. + } #pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. return task.GetAwaiter().GetResult(); #pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. @@ -53,12 +66,18 @@ public static T EnsureCompleted(this ValueTask task) public static void EnsureCompleted(this ValueTask task) { -#if DEBUG - VerifyTaskCompleted(task.IsCompleted); -#endif + if (!task.IsCompleted) + { +#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available. + EnsureCompleted(task.AsTask()); +#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available. + } + else + { #pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. - task.GetAwaiter().GetResult(); + task.GetAwaiter().GetResult(); #pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. + } } public static Enumerable EnsureSyncEnumerable(this IAsyncEnumerable asyncEnumerable) => new Enumerable(asyncEnumerable); @@ -101,6 +120,9 @@ private static void VerifyTaskCompleted(bool isCompleted) } } + private static bool HasSynchronizationContext() + => SynchronizationContext.Current != null && SynchronizationContext.Current.GetType() != typeof(SynchronizationContext) || TaskScheduler.Current != TaskScheduler.Default; + /// /// Both and are defined as public structs so that foreach can use duck typing /// to call and avoid heap memory allocation. diff --git a/sdk/core/Azure.Core/tests/TaskExtensionsTest.cs b/sdk/core/Azure.Core/tests/TaskExtensionsTest.cs index 4ce2839ef361..9fe2dd399949 100644 --- a/sdk/core/Azure.Core/tests/TaskExtensionsTest.cs +++ b/sdk/core/Azure.Core/tests/TaskExtensionsTest.cs @@ -4,6 +4,7 @@ using Azure.Core.Pipeline; using NUnit.Framework; using System; +using System.Collections.Concurrent; using System.Threading; using System.Threading.Tasks; @@ -11,6 +12,136 @@ namespace Azure.Core.Tests { public class TaskExtensionsTest { + [Test] + public void TaskExtensions_TaskEnsureCompleted() + { + var task = Task.CompletedTask; + task.EnsureCompleted(); + } + + [Test] + public void TaskExtensions_TaskOfTEnsureCompleted() + { + var task = Task.FromResult(42); + Assert.AreEqual(42, task.EnsureCompleted()); + } + + [Test] + public void TaskExtensions_ValueTaskEnsureCompleted() + { + var task = new ValueTask(); + task.EnsureCompleted(); + } + + [Test] + public void TaskExtensions_ValueTaskOfTEnsureCompleted() + { + var task = new ValueTask(42); + Assert.AreEqual(42, task.EnsureCompleted()); + } + + [Test] + public async Task TaskExtensions_TaskEnsureCompleted_NotCompletedNoSyncContext() + { + var tcs = new TaskCompletionSource(); + Task task = tcs.Task; +#if DEBUG + Assert.Catch(() => task.EnsureCompleted()); + await Task.CompletedTask; +#else + Task runningTask = Task.Run(() => task.EnsureCompleted()); + Assert.IsFalse(runningTask.IsCompleted); + tcs.SetResult(0); + await runningTask; +#endif + } + + [Test] + public async Task TaskExtensions_TaskOfTEnsureCompleted_NotCompletedNoSyncContext() + { + var tcs = new TaskCompletionSource(); +#if DEBUG + Assert.Catch(() => tcs.Task.EnsureCompleted()); + await Task.CompletedTask; +#else + Task runningTask = Task.Run(() => tcs.Task.EnsureCompleted()); + Assert.IsFalse(runningTask.IsCompleted); + tcs.SetResult(42); + Assert.AreEqual(42, await runningTask); +#endif + } + + [Test] + public async Task TaskExtensions_ValueTaskEnsureCompleted_NotCompletedNoSyncContext() + { + var tcs = new TaskCompletionSource(); + ValueTask task = new ValueTask(tcs.Task); +#if DEBUG + Assert.Catch(() => task.EnsureCompleted()); + await Task.CompletedTask; +#else + Task runningTask = Task.Run(() => task.EnsureCompleted()); + Assert.IsFalse(runningTask.IsCompleted); + tcs.SetResult(0); + await runningTask; +#endif + } + + [Test] + public async Task TaskExtensions_ValueTaskOfTEnsureCompleted_NotCompletedNoSyncContext() + { + var tcs = new TaskCompletionSource(); + ValueTask task = new ValueTask(tcs.Task); +#if DEBUG + Assert.Catch(() => task.EnsureCompleted()); + await Task.CompletedTask; +#else + Task runningTask = Task.Run(() => task.EnsureCompleted()); + Assert.IsFalse(runningTask.IsCompleted); + tcs.SetResult(42); + Assert.AreEqual(42, await runningTask); +#endif + } + + [Test] + public void TaskExtensions_TaskEnsureCompleted_NotCompletedInSyncContext() + { + using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext(); + var tcs = new TaskCompletionSource(); + Task task = tcs.Task; + + syncContext.Post(t => { Assert.Catch(() => task.EnsureCompleted()); }, null); + } + + [Test] + public void TaskExtensions_TaskOfTEnsureCompleted_NotCompletedInSyncContext() + { + using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext(); + var tcs = new TaskCompletionSource(); + + syncContext.Post(t => { Assert.Catch(() => tcs.Task.EnsureCompleted()); }, null); + } + + [Test] + public void TaskExtensions_ValueTaskEnsureCompleted_NotCompletedInSyncContext() + { + using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext(); + var tcs = new TaskCompletionSource(); + ValueTask task = new ValueTask(tcs.Task); + + syncContext.Post(t => { Assert.Catch(() => task.EnsureCompleted()); }, null); + } + + [Test] + public void TaskExtensions_ValueTaskOfTEnsureCompleted_NotCompletedInSyncContext() + { + using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext(); + var tcs = new TaskCompletionSource(); + var task = new ValueTask(tcs.Task); + + syncContext.Post(t => { Assert.Catch(() => task.EnsureCompleted()); }, null); + } + [Test] public void TaskExtensions_TaskWithCancellationDefault() { @@ -192,5 +323,55 @@ public void TaskExtensions_ValueTaskWithCancellationFailedAfterContinuationSched Assert.AreEqual(true, awaiter.IsCompleted); Assert.Catch(() => awaiter.GetResult(), "Error"); } + + private sealed class SingleThreadedSynchronizationContext : SynchronizationContext, IDisposable + { + private readonly Task _task; + private readonly BlockingCollection _queue; + private readonly ConcurrentQueue _exceptions; + + public SingleThreadedSynchronizationContext() + { + _queue = new BlockingCollection(); + _exceptions = new ConcurrentQueue(); + _task = Task.Run(RunLoop); + } + + private void RunLoop() + { + try + { + SetSynchronizationContext(this); + while (!_queue.IsCompleted) + { + Action action = _queue.Take(); + try + { + action(); + } + catch (Exception e) + { + _exceptions.Enqueue(e); + } + } + } + catch (InvalidOperationException) { } + catch (OperationCanceledException) { } + finally + { + SetSynchronizationContext(null); + } + } + + public override void Post(SendOrPostCallback d, object state) => _queue.Add(() => d(state)); + + public void Dispose() + { + _queue.CompleteAdding(); + _task.Wait(); + } + + public AggregateException Exceptions => new AggregateException(_exceptions); + } } }