Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions sdk/core/Azure.Core/src/Shared/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public static T EnsureCompleted<T>(this Task<T> task)
{
#if DEBUG
VerifyTaskCompleted(task.IsCompleted);
#else
if (HasSynchronizationContext())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is worth doing. Might break customers apps that worked just fine before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may for some custom non-blocking sync context. In that case, they will report a bug and use Task.Run as a workaround while we fix it. However, if we don't throw in this case, the most common sync context scenario of interaction with UI will hang without any clue for the user why it has happened.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is of course if we have any bugs like that at all (this exception should be thrown only if we have sync code that waits on incomplete task while should only get the result of completed).

{
throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed.");
}
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
return task.GetAwaiter().GetResult();
Expand All @@ -35,6 +40,11 @@ public static void EnsureCompleted(this Task task)
{
#if DEBUG
VerifyTaskCompleted(task.IsCompleted);
#else
if (HasSynchronizationContext())
{
throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed.");
}
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
task.GetAwaiter().GetResult();
Expand All @@ -43,22 +53,31 @@ public static void EnsureCompleted(this Task task)

public static T EnsureCompleted<T>(this ValueTask<T> task)
{
#if DEBUG
VerifyTaskCompleted(task.IsCompleted);
#endif
if (!task.IsCompleted)
{
#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
return EnsureCompleted(task.AsTask());
#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
}
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
return task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
}

public static void EnsureCompleted(this ValueTask task)
{
#if DEBUG
VerifyTaskCompleted(task.IsCompleted);
#endif
if (!task.IsCompleted)
{
#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
EnsureCompleted(task.AsTask());
#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
}
else
{
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
task.GetAwaiter().GetResult();
task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
}
}

public static Enumerable<T> EnsureSyncEnumerable<T>(this IAsyncEnumerable<T> asyncEnumerable) => new Enumerable<T>(asyncEnumerable);
Expand Down Expand Up @@ -101,6 +120,9 @@ private static void VerifyTaskCompleted(bool isCompleted)
}
}

private static bool HasSynchronizationContext()
=> SynchronizationContext.Current != null && SynchronizationContext.Current.GetType() != typeof(SynchronizationContext) || TaskScheduler.Current != TaskScheduler.Default;

/// <summary>
/// Both <see cref="Enumerable{T}"/> and <see cref="Enumerator{T}"/> are defined as public structs so that foreach can use duck typing
/// to call <see cref="Enumerable{T}.GetEnumerator"/> and avoid heap memory allocation.
Expand Down