Skip to content

Commit add5764

Browse files
authored
Preserve CurrentThreadExecutor across create_task (#320)
Fixes #214.
1 parent d451a72 commit add5764

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

asgiref/sync.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import asyncio.coroutines
23
import contextvars
34
import functools
@@ -101,6 +102,10 @@ class AsyncToSync:
101102
# Local, not a threadlocal, so that tasks can work out what their parent used.
102103
executors = Local()
103104

105+
# When we can't find a CurrentThreadExecutor from the context, such as
106+
# inside create_task, we'll look it up here from the running event loop.
107+
loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}
108+
104109
def __init__(self, awaitable, force_new_loop=False):
105110
if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
106111
# Python does not have very reliable detection of async functions
@@ -164,6 +169,7 @@ def __call__(self, *args, **kwargs):
164169
old_current_executor = None
165170
current_executor = CurrentThreadExecutor()
166171
self.executors.current = current_executor
172+
loop = None
167173
# Use call_soon_threadsafe to schedule a synchronous callback on the
168174
# main event loop's thread if it's there, otherwise make a new loop
169175
# in this thread.
@@ -175,6 +181,7 @@ def __call__(self, *args, **kwargs):
175181
if not (self.main_event_loop and self.main_event_loop.is_running()):
176182
# Make our own event loop - in a new thread - and run inside that.
177183
loop = asyncio.new_event_loop()
184+
self.loop_thread_executors[loop] = current_executor
178185
loop_executor = ThreadPoolExecutor(max_workers=1)
179186
loop_future = loop_executor.submit(
180187
self._run_event_loop, loop, awaitable
@@ -194,6 +201,8 @@ def __call__(self, *args, **kwargs):
194201
current_executor.run_until_future(call_result)
195202
finally:
196203
# Clean up any executor we were running
204+
if loop is not None:
205+
del self.loop_thread_executors[loop]
197206
if hasattr(self.executors, "current"):
198207
del self.executors.current
199208
if old_current_executor:
@@ -378,6 +387,9 @@ async def __call__(self, *args, **kwargs):
378387
# Create new thread executor in current context
379388
executor = ThreadPoolExecutor(max_workers=1)
380389
self.context_to_thread_executor[thread_sensitive_context] = executor
390+
elif loop in AsyncToSync.loop_thread_executors:
391+
# Re-use thread executor for running loop
392+
executor = AsyncToSync.loop_thread_executors[loop]
381393
elif self.deadlock_context and self.deadlock_context.get(False):
382394
raise RuntimeError(
383395
"Single thread executor already being used, would deadlock"

tests/test_sync.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,21 @@ def test_thread_sensitive_outside_sync():
397397
@async_to_sync
398398
async def middle():
399399
await inner()
400+
await asyncio.create_task(inner_task())
400401

401-
# Inner sync function
402+
# Inner sync functions
402403
@sync_to_async
403404
def inner():
404405
result["thread"] = threading.current_thread()
405406

407+
@sync_to_async
408+
def inner_task():
409+
result["thread2"] = threading.current_thread()
410+
406411
# Run it
407412
middle()
408413
assert result["thread"] == threading.current_thread()
414+
assert result["thread2"] == threading.current_thread()
409415

410416

411417
@pytest.mark.asyncio

0 commit comments

Comments
 (0)