Skip to content

Commit

Permalink
bpo-46994: Accept explicit contextvars.Context in asyncio create_task…
Browse files Browse the repository at this point in the history
…() API (GH-31837)
  • Loading branch information
asvetlov authored Mar 14, 2022
1 parent 2153daf commit 9523c0d
Show file tree
Hide file tree
Showing 13 changed files with 209 additions and 65 deletions.
11 changes: 9 additions & 2 deletions Doc/library/asyncio-eventloop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ Creating Futures and Tasks

.. versionadded:: 3.5.2

.. method:: loop.create_task(coro, *, name=None)
.. method:: loop.create_task(coro, *, name=None, context=None)

Schedule the execution of a :ref:`coroutine`.
Return a :class:`Task` object.
Expand All @@ -342,17 +342,24 @@ Creating Futures and Tasks
If the *name* argument is provided and not ``None``, it is set as
the name of the task using :meth:`Task.set_name`.

An optional keyword-only *context* argument allows specifying a
custom :class:`contextvars.Context` for the *coro* to run in.
The current context copy is created when no *context* is provided.

.. versionchanged:: 3.8
Added the *name* parameter.

.. versionchanged:: 3.11
Added the *context* parameter.

.. method:: loop.set_task_factory(factory)

Set a task factory that will be used by
:meth:`loop.create_task`.

If *factory* is ``None`` the default task factory will be set.
Otherwise, *factory* must be a *callable* with the signature matching
``(loop, coro)``, where *loop* is a reference to the active
``(loop, coro, context=None)``, where *loop* is a reference to the active
event loop, and *coro* is a coroutine object. The callable
must return a :class:`asyncio.Future`-compatible object.

Expand Down
9 changes: 8 additions & 1 deletion Doc/library/asyncio-task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,18 @@ Running an asyncio Program
Creating Tasks
==============

.. function:: create_task(coro, *, name=None)
.. function:: create_task(coro, *, name=None, context=None)

Wrap the *coro* :ref:`coroutine <coroutine>` into a :class:`Task`
and schedule its execution. Return the Task object.

If *name* is not ``None``, it is set as the name of the task using
:meth:`Task.set_name`.

An optional keyword-only *context* argument allows specifying a
custom :class:`contextvars.Context` for the *coro* to run in.
The current context copy is created when no *context* is provided.

The task is executed in the loop returned by :func:`get_running_loop`,
:exc:`RuntimeError` is raised if there is no running loop in
current thread.
Expand Down Expand Up @@ -281,6 +285,9 @@ Creating Tasks
.. versionchanged:: 3.8
Added the *name* parameter.

.. versionchanged:: 3.11
Added the *context* parameter.


Sleeping
========
Expand Down
11 changes: 8 additions & 3 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,18 +426,23 @@ def create_future(self):
"""Create a Future object attached to the loop."""
return futures.Future(loop=self)

def create_task(self, coro, *, name=None):
def create_task(self, coro, *, name=None, context=None):
"""Schedule a coroutine object.
Return a task object.
"""
self._check_closed()
if self._task_factory is None:
task = tasks.Task(coro, loop=self, name=name)
task = tasks.Task(coro, loop=self, name=name, context=context)
if task._source_traceback:
del task._source_traceback[-1]
else:
task = self._task_factory(self, coro)
if context is None:
# Use legacy API if context is not needed
task = self._task_factory(self, coro)
else:
task = self._task_factory(self, coro, context=context)

tasks._set_task_name(task, name)

return task
Expand Down
2 changes: 1 addition & 1 deletion Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def create_future(self):

# Method scheduling a coroutine object: create a task.

def create_task(self, coro, *, name=None):
def create_task(self, coro, *, name=None, context=None):
raise NotImplementedError

# Methods for interacting with threads.
Expand Down
7 changes: 5 additions & 2 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,15 @@ async def __aexit__(self, et, exc, tb):
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None

def create_task(self, coro, *, name=None):
def create_task(self, coro, *, name=None, context=None):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and self._unfinished_tasks == 0:
raise RuntimeError(f"TaskGroup {self!r} is finished")
task = self._loop.create_task(coro)
if context is None:
task = self._loop.create_task(coro)
else:
task = self._loop.create_task(coro, context=context)
tasks._set_task_name(task, name)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
Expand Down
16 changes: 12 additions & 4 deletions Lib/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
# status is still pending
_log_destroy_pending = True

def __init__(self, coro, *, loop=None, name=None):
def __init__(self, coro, *, loop=None, name=None, context=None):
super().__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
Expand All @@ -112,7 +112,10 @@ def __init__(self, coro, *, loop=None, name=None):
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
self._context = contextvars.copy_context()
if context is None:
self._context = contextvars.copy_context()
else:
self._context = context

self._loop.call_soon(self.__step, context=self._context)
_register_task(self)
Expand Down Expand Up @@ -360,13 +363,18 @@ def __wakeup(self, future):
Task = _CTask = _asyncio.Task


def create_task(coro, *, name=None):
def create_task(coro, *, name=None, context=None):
"""Schedule the execution of a coroutine object in a spawn task.
Return a Task object.
"""
loop = events.get_running_loop()
task = loop.create_task(coro)
if context is None:
# Use legacy API if context is not needed
task = loop.create_task(coro)
else:
task = loop.create_task(coro, context=context)

_set_task_name(task, name)
return task

Expand Down
18 changes: 18 additions & 0 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import asyncio
import contextvars

from asyncio import taskgroups
import unittest
Expand Down Expand Up @@ -708,6 +709,23 @@ async def coro():
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")

async def test_taskgroup_task_context(self):
cvar = contextvars.ContextVar('cvar')

async def coro(val):
await asyncio.sleep(0)
cvar.set(val)

async with taskgroups.TaskGroup() as g:
ctx = contextvars.copy_context()
self.assertIsNone(ctx.get(cvar))
t1 = g.create_task(coro(1), context=ctx)
await t1
self.assertEqual(1, ctx.get(cvar))
t2 = g.create_task(coro(2), context=ctx)
await t2
self.assertEqual(2, ctx.get(cvar))


if __name__ == "__main__":
unittest.main()
88 changes: 86 additions & 2 deletions Lib/test/test_asyncio/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class BaseTaskTests:
Task = None
Future = None

def new_task(self, loop, coro, name='TestTask'):
return self.__class__.Task(coro, loop=loop, name=name)
def new_task(self, loop, coro, name='TestTask', context=None):
return self.__class__.Task(coro, loop=loop, name=name, context=context)

def new_future(self, loop):
return self.__class__.Future(loop=loop)
Expand Down Expand Up @@ -2527,6 +2527,90 @@ async def main():

self.assertEqual(cvar.get(), -1)

def test_context_4(self):
cvar = contextvars.ContextVar('cvar')

async def coro(val):
await asyncio.sleep(0)
cvar.set(val)

async def main():
ret = []
ctx = contextvars.copy_context()
ret.append(ctx.get(cvar))
t1 = self.new_task(loop, coro(1), context=ctx)
await t1
ret.append(ctx.get(cvar))
t2 = self.new_task(loop, coro(2), context=ctx)
await t2
ret.append(ctx.get(cvar))
return ret

loop = asyncio.new_event_loop()
try:
task = self.new_task(loop, main())
ret = loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual([None, 1, 2], ret)

def test_context_5(self):
cvar = contextvars.ContextVar('cvar')

async def coro(val):
await asyncio.sleep(0)
cvar.set(val)

async def main():
ret = []
ctx = contextvars.copy_context()
ret.append(ctx.get(cvar))
t1 = asyncio.create_task(coro(1), context=ctx)
await t1
ret.append(ctx.get(cvar))
t2 = asyncio.create_task(coro(2), context=ctx)
await t2
ret.append(ctx.get(cvar))
return ret

loop = asyncio.new_event_loop()
try:
task = self.new_task(loop, main())
ret = loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual([None, 1, 2], ret)

def test_context_6(self):
cvar = contextvars.ContextVar('cvar')

async def coro(val):
await asyncio.sleep(0)
cvar.set(val)

async def main():
ret = []
ctx = contextvars.copy_context()
ret.append(ctx.get(cvar))
t1 = loop.create_task(coro(1), context=ctx)
await t1
ret.append(ctx.get(cvar))
t2 = loop.create_task(coro(2), context=ctx)
await t2
ret.append(ctx.get(cvar))
return ret

loop = asyncio.new_event_loop()
try:
task = loop.create_task(main())
ret = loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual([None, 1, 2], ret)

def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
Expand Down
55 changes: 17 additions & 38 deletions Lib/unittest/async_case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextvars
import inspect
import warnings

Expand Down Expand Up @@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase):
def __init__(self, methodName='runTest'):
super().__init__(methodName)
self._asyncioTestLoop = None
self._asyncioCallsQueue = None
self._asyncioTestContext = contextvars.copy_context()

async def asyncSetUp(self):
pass
Expand All @@ -58,7 +59,7 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)

def _callSetUp(self):
self.setUp()
self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)

def _callTestMethod(self, method):
Expand All @@ -68,64 +69,42 @@ def _callTestMethod(self, method):

def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()
self._asyncioTestContext.run(self.tearDown)

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
ret = func(*args, **kwargs)
assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
fut = self._asyncioTestLoop.create_future()
self._asyncioCallsQueue.put_nowait((fut, ret))
return self._asyncioTestLoop.run_until_complete(fut)
assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
task = self._asyncioTestLoop.create_task(
func(*args, **kwargs),
context=self._asyncioTestContext,
)
return self._asyncioTestLoop.run_until_complete(task)

def _callMaybeAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
ret = func(*args, **kwargs)
if inspect.isawaitable(ret):
fut = self._asyncioTestLoop.create_future()
self._asyncioCallsQueue.put_nowait((fut, ret))
return self._asyncioTestLoop.run_until_complete(fut)
if inspect.iscoroutinefunction(func):
task = self._asyncioTestLoop.create_task(
func(*args, **kwargs),
context=self._asyncioTestContext,
)
return self._asyncioTestLoop.run_until_complete(task)
else:
return ret

async def _asyncioLoopRunner(self, fut):
self._asyncioCallsQueue = queue = asyncio.Queue()
fut.set_result(None)
while True:
query = await queue.get()
queue.task_done()
if query is None:
return
fut, awaitable = query
try:
ret = await awaitable
if not fut.cancelled():
fut.set_result(ret)
except (SystemExit, KeyboardInterrupt):
raise
except (BaseException, asyncio.CancelledError) as ex:
if not fut.cancelled():
fut.set_exception(ex)
return self._asyncioTestContext.run(func, *args, **kwargs)

def _setupAsyncioLoop(self):
assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.set_debug(True)
self._asyncioTestLoop = loop
fut = loop.create_future()
self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
loop.run_until_complete(fut)

def _tearDownAsyncioLoop(self):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
loop = self._asyncioTestLoop
self._asyncioTestLoop = None
self._asyncioCallsQueue.put_nowait(None)
loop.run_until_complete(self._asyncioCallsQueue.join())

try:
# cancel all tasks
Expand Down
Loading

0 comments on commit 9523c0d

Please sign in to comment.