Skip to content

Commit 51477a8

Browse files
committed
Revert "Fix Azure#10904: Task extensions synchronously complete ValueTasks (Azure#13719)"
This reverts commit 49bcc3e.
1 parent f9b7fcc commit 51477a8

File tree

2 files changed

+7
-210
lines changed

2 files changed

+7
-210
lines changed

sdk/core/Azure.Core/src/Shared/TaskExtensions.cs

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ public static T EnsureCompleted<T>(this Task<T> task)
2525
{
2626
#if DEBUG
2727
VerifyTaskCompleted(task.IsCompleted);
28-
#else
29-
if (HasSynchronizationContext())
30-
{
31-
throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed.");
32-
}
3328
#endif
3429
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
3530
return task.GetAwaiter().GetResult();
@@ -40,11 +35,6 @@ public static void EnsureCompleted(this Task task)
4035
{
4136
#if DEBUG
4237
VerifyTaskCompleted(task.IsCompleted);
43-
#else
44-
if (HasSynchronizationContext())
45-
{
46-
throw new InvalidOperationException("Synchronously waiting on non-completed task isn't allowed.");
47-
}
4838
#endif
4939
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
5040
task.GetAwaiter().GetResult();
@@ -53,31 +43,22 @@ public static void EnsureCompleted(this Task task)
5343

5444
public static T EnsureCompleted<T>(this ValueTask<T> task)
5545
{
56-
if (!task.IsCompleted)
57-
{
58-
#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
59-
return EnsureCompleted(task.AsTask());
60-
#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
61-
}
46+
#if DEBUG
47+
VerifyTaskCompleted(task.IsCompleted);
48+
#endif
6249
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
6350
return task.GetAwaiter().GetResult();
6451
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
6552
}
6653

6754
public static void EnsureCompleted(this ValueTask task)
6855
{
69-
if (!task.IsCompleted)
70-
{
71-
#pragma warning disable AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
72-
EnsureCompleted(task.AsTask());
73-
#pragma warning restore AZC0107 // public asynchronous method shouldn't be called in synchronous scope. Use synchronous version of the method if it is available.
74-
}
75-
else
76-
{
56+
#if DEBUG
57+
VerifyTaskCompleted(task.IsCompleted);
58+
#endif
7759
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
78-
task.GetAwaiter().GetResult();
60+
task.GetAwaiter().GetResult();
7961
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
80-
}
8162
}
8263

8364
public static Enumerable<T> EnsureSyncEnumerable<T>(this IAsyncEnumerable<T> asyncEnumerable) => new Enumerable<T>(asyncEnumerable);
@@ -120,9 +101,6 @@ private static void VerifyTaskCompleted(bool isCompleted)
120101
}
121102
}
122103

123-
private static bool HasSynchronizationContext()
124-
=> SynchronizationContext.Current != null && SynchronizationContext.Current.GetType() != typeof(SynchronizationContext) || TaskScheduler.Current != TaskScheduler.Default;
125-
126104
/// <summary>
127105
/// Both <see cref="Enumerable{T}"/> and <see cref="Enumerator{T}"/> are defined as public structs so that foreach can use duck typing
128106
/// to call <see cref="Enumerable{T}.GetEnumerator"/> and avoid heap memory allocation.

sdk/core/Azure.Core/tests/TaskExtensionsTest.cs

Lines changed: 0 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -4,144 +4,13 @@
44
using Azure.Core.Pipeline;
55
using NUnit.Framework;
66
using System;
7-
using System.Collections.Concurrent;
87
using System.Threading;
98
using System.Threading.Tasks;
109

1110
namespace Azure.Core.Tests
1211
{
1312
public class TaskExtensionsTest
1413
{
15-
[Test]
16-
public void TaskExtensions_TaskEnsureCompleted()
17-
{
18-
var task = Task.CompletedTask;
19-
task.EnsureCompleted();
20-
}
21-
22-
[Test]
23-
public void TaskExtensions_TaskOfTEnsureCompleted()
24-
{
25-
var task = Task.FromResult(42);
26-
Assert.AreEqual(42, task.EnsureCompleted());
27-
}
28-
29-
[Test]
30-
public void TaskExtensions_ValueTaskEnsureCompleted()
31-
{
32-
var task = new ValueTask();
33-
task.EnsureCompleted();
34-
}
35-
36-
[Test]
37-
public void TaskExtensions_ValueTaskOfTEnsureCompleted()
38-
{
39-
var task = new ValueTask<int>(42);
40-
Assert.AreEqual(42, task.EnsureCompleted());
41-
}
42-
43-
[Test]
44-
public async Task TaskExtensions_TaskEnsureCompleted_NotCompletedNoSyncContext()
45-
{
46-
var tcs = new TaskCompletionSource<int>();
47-
Task task = tcs.Task;
48-
#if DEBUG
49-
Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted());
50-
await Task.CompletedTask;
51-
#else
52-
Task runningTask = Task.Run(() => task.EnsureCompleted());
53-
Assert.IsFalse(runningTask.IsCompleted);
54-
tcs.SetResult(0);
55-
await runningTask;
56-
#endif
57-
}
58-
59-
[Test]
60-
public async Task TaskExtensions_TaskOfTEnsureCompleted_NotCompletedNoSyncContext()
61-
{
62-
var tcs = new TaskCompletionSource<int>();
63-
#if DEBUG
64-
Assert.Catch<InvalidOperationException>(() => tcs.Task.EnsureCompleted());
65-
await Task.CompletedTask;
66-
#else
67-
Task<int> runningTask = Task.Run(() => tcs.Task.EnsureCompleted());
68-
Assert.IsFalse(runningTask.IsCompleted);
69-
tcs.SetResult(42);
70-
Assert.AreEqual(42, await runningTask);
71-
#endif
72-
}
73-
74-
[Test]
75-
public async Task TaskExtensions_ValueTaskEnsureCompleted_NotCompletedNoSyncContext()
76-
{
77-
var tcs = new TaskCompletionSource<int>();
78-
ValueTask task = new ValueTask(tcs.Task);
79-
#if DEBUG
80-
Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted());
81-
await Task.CompletedTask;
82-
#else
83-
Task runningTask = Task.Run(() => task.EnsureCompleted());
84-
Assert.IsFalse(runningTask.IsCompleted);
85-
tcs.SetResult(0);
86-
await runningTask;
87-
#endif
88-
}
89-
90-
[Test]
91-
public async Task TaskExtensions_ValueTaskOfTEnsureCompleted_NotCompletedNoSyncContext()
92-
{
93-
var tcs = new TaskCompletionSource<int>();
94-
ValueTask<int> task = new ValueTask<int>(tcs.Task);
95-
#if DEBUG
96-
Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted());
97-
await Task.CompletedTask;
98-
#else
99-
Task<int> runningTask = Task.Run(() => task.EnsureCompleted());
100-
Assert.IsFalse(runningTask.IsCompleted);
101-
tcs.SetResult(42);
102-
Assert.AreEqual(42, await runningTask);
103-
#endif
104-
}
105-
106-
[Test]
107-
public void TaskExtensions_TaskEnsureCompleted_NotCompletedInSyncContext()
108-
{
109-
using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext();
110-
var tcs = new TaskCompletionSource<int>();
111-
Task task = tcs.Task;
112-
113-
syncContext.Post(t => { Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted()); }, null);
114-
}
115-
116-
[Test]
117-
public void TaskExtensions_TaskOfTEnsureCompleted_NotCompletedInSyncContext()
118-
{
119-
using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext();
120-
var tcs = new TaskCompletionSource<int>();
121-
122-
syncContext.Post(t => { Assert.Catch<InvalidOperationException>(() => tcs.Task.EnsureCompleted()); }, null);
123-
}
124-
125-
[Test]
126-
public void TaskExtensions_ValueTaskEnsureCompleted_NotCompletedInSyncContext()
127-
{
128-
using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext();
129-
var tcs = new TaskCompletionSource<int>();
130-
ValueTask task = new ValueTask(tcs.Task);
131-
132-
syncContext.Post(t => { Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted()); }, null);
133-
}
134-
135-
[Test]
136-
public void TaskExtensions_ValueTaskOfTEnsureCompleted_NotCompletedInSyncContext()
137-
{
138-
using SingleThreadedSynchronizationContext syncContext = new SingleThreadedSynchronizationContext();
139-
var tcs = new TaskCompletionSource<int>();
140-
var task = new ValueTask<int>(tcs.Task);
141-
142-
syncContext.Post(t => { Assert.Catch<InvalidOperationException>(() => task.EnsureCompleted()); }, null);
143-
}
144-
14514
[Test]
14615
public void TaskExtensions_TaskWithCancellationDefault()
14716
{
@@ -323,55 +192,5 @@ public void TaskExtensions_ValueTaskWithCancellationFailedAfterContinuationSched
323192
Assert.AreEqual(true, awaiter.IsCompleted);
324193
Assert.Catch<OperationCanceledException>(() => awaiter.GetResult(), "Error");
325194
}
326-
327-
private sealed class SingleThreadedSynchronizationContext : SynchronizationContext, IDisposable
328-
{
329-
private readonly Task _task;
330-
private readonly BlockingCollection<Action> _queue;
331-
private readonly ConcurrentQueue<Exception> _exceptions;
332-
333-
public SingleThreadedSynchronizationContext()
334-
{
335-
_queue = new BlockingCollection<Action>();
336-
_exceptions = new ConcurrentQueue<Exception>();
337-
_task = Task.Run(RunLoop);
338-
}
339-
340-
private void RunLoop()
341-
{
342-
try
343-
{
344-
SetSynchronizationContext(this);
345-
while (!_queue.IsCompleted)
346-
{
347-
Action action = _queue.Take();
348-
try
349-
{
350-
action();
351-
}
352-
catch (Exception e)
353-
{
354-
_exceptions.Enqueue(e);
355-
}
356-
}
357-
}
358-
catch (InvalidOperationException) { }
359-
catch (OperationCanceledException) { }
360-
finally
361-
{
362-
SetSynchronizationContext(null);
363-
}
364-
}
365-
366-
public override void Post(SendOrPostCallback d, object state) => _queue.Add(() => d(state));
367-
368-
public void Dispose()
369-
{
370-
_queue.CompleteAdding();
371-
_task.Wait();
372-
}
373-
374-
public AggregateException Exceptions => new AggregateException(_exceptions);
375-
}
376195
}
377196
}

0 commit comments

Comments
 (0)