Skip to content

Commit 746c114

Browse files
bcmillsseifertm
authored andcommitted
Maintain contextvars.Context in fixtures and tests
The approach I've taken here is to maintain a contextvars.Context instance in a contextvars.ContextVar, copying it from the ambient context whenever we create a new event loop. The fixture setup and teardown run within that context, and each test function gets a copy (as if it were created as a new asyncio.Task from within the fixture task). Fixes #127.
1 parent ebbd602 commit 746c114

File tree

2 files changed

+101
-7
lines changed

2 files changed

+101
-7
lines changed

pytest_asyncio/plugin.py

+65-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import contextlib
7+
import contextvars
78
import enum
89
import functools
910
import inspect
@@ -318,6 +319,8 @@ def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
318319
kwargs.pop(event_loop_fixture_id, None)
319320
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))
320321

322+
context = _event_loop_context.get(None)
323+
321324
async def setup():
322325
res = await gen_obj.__anext__() # type: ignore[union-attr]
323326
return res
@@ -335,9 +338,11 @@ async def async_finalizer() -> None:
335338
msg += "Yield only once."
336339
raise ValueError(msg)
337340

338-
event_loop.run_until_complete(async_finalizer())
341+
task = _create_task_in_context(event_loop, async_finalizer(), context)
342+
event_loop.run_until_complete(task)
339343

340-
result = event_loop.run_until_complete(setup())
344+
setup_task = _create_task_in_context(event_loop, setup(), context)
345+
result = event_loop.run_until_complete(setup_task)
341346
request.addfinalizer(finalizer)
342347
return result
343348

@@ -360,7 +365,10 @@ async def setup():
360365
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
361366
return res
362367

363-
return event_loop.run_until_complete(setup())
368+
task = _create_task_in_context(
369+
event_loop, setup(), _event_loop_context.get(None)
370+
)
371+
return event_loop.run_until_complete(task)
364372

365373
fixturedef.func = _async_fixture_wrapper # type: ignore[misc]
366374

@@ -584,6 +592,46 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
584592
Session: "session",
585593
}
586594

595+
# _event_loop_context stores the Context in which asyncio tasks on the fixture
596+
# event loop should be run. After fixture setup, individual async test functions
597+
# are run on copies of this context.
598+
_event_loop_context: contextvars.ContextVar[contextvars.Context] = (
599+
contextvars.ContextVar("pytest_asyncio_event_loop_context")
600+
)
601+
602+
603+
@contextlib.contextmanager
604+
def _set_event_loop_context():
605+
"""Set event_loop_context to a copy of the calling thread's current context."""
606+
context = contextvars.copy_context()
607+
token = _event_loop_context.set(context)
608+
try:
609+
yield
610+
finally:
611+
_event_loop_context.reset(token)
612+
613+
614+
def _create_task_in_context(loop, coro, context):
615+
"""
616+
Return an asyncio task that runs the coro in the specified context,
617+
if possible.
618+
619+
This allows fixture setup and teardown to be run as separate asyncio tasks,
620+
while still being able to use context-manager idioms to maintain context
621+
variables and make those variables visible to test functions.
622+
623+
This is only fully supported on Python 3.11 and newer, as it requires
624+
the API added for https://github.com/python/cpython/issues/91150.
625+
On earlier versions, the returned task will use the default context instead.
626+
"""
627+
if context is not None:
628+
try:
629+
return loop.create_task(coro, context=context)
630+
except TypeError:
631+
pass
632+
return loop.create_task(coro)
633+
634+
587635
# A stack used to push package-scoped loops during collection of a package
588636
# and pop those loops during collection of a Module
589637
__package_loop_stack: list[FixtureFunctionMarker | FixtureFunction] = []
@@ -631,7 +679,8 @@ def scoped_event_loop(
631679
loop = asyncio.new_event_loop()
632680
loop.__pytest_asyncio = True # type: ignore[attr-defined]
633681
asyncio.set_event_loop(loop)
634-
yield loop
682+
with _set_event_loop_context():
683+
yield loop
635684
loop.close()
636685

637686
# @pytest.fixture does not register the fixture anywhere, so pytest doesn't
@@ -938,9 +987,16 @@ def wrap_in_sync(
938987

939988
@functools.wraps(func)
940989
def inner(*args, **kwargs):
990+
# Give each test its own context based on the loop's main context.
991+
context = _event_loop_context.get(None)
992+
if context is not None:
993+
# We are using our own event loop fixture, so make a new copy of the
994+
# fixture context so that the test won't pollute it.
995+
context = context.copy()
996+
941997
coro = func(*args, **kwargs)
942998
_loop = _get_event_loop_no_warn()
943-
task = asyncio.ensure_future(coro, loop=_loop)
999+
task = _create_task_in_context(_loop, coro, context)
9441000
try:
9451001
_loop.run_until_complete(task)
9461002
except BaseException:
@@ -1049,7 +1105,8 @@ def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
10491105
# The magic value must be set as part of the function definition, because pytest
10501106
# seems to have multiple instances of the same FixtureDef or fixture function
10511107
loop.__original_fixture_loop = True # type: ignore[attr-defined]
1052-
yield loop
1108+
with _set_event_loop_context():
1109+
yield loop
10531110
loop.close()
10541111

10551112

@@ -1062,7 +1119,8 @@ def _session_event_loop(
10621119
loop = asyncio.new_event_loop()
10631120
loop.__pytest_asyncio = True # type: ignore[attr-defined]
10641121
asyncio.set_event_loop(loop)
1065-
yield loop
1122+
with _set_event_loop_context():
1123+
yield loop
10661124
loop.close()
10671125

10681126

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
Regression test for https://github.com/pytest-dev/pytest-asyncio/issues/127:
3+
contextvars were not properly maintained among fixtures and tests.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import sys
9+
from contextlib import asynccontextmanager
10+
from contextvars import ContextVar
11+
12+
import pytest
13+
14+
15+
@asynccontextmanager
16+
async def context_var_manager():
17+
context_var = ContextVar("context_var")
18+
token = context_var.set("value")
19+
try:
20+
yield context_var
21+
finally:
22+
context_var.reset(token)
23+
24+
25+
@pytest.fixture(scope="function")
26+
async def context_var():
27+
async with context_var_manager() as v:
28+
yield v
29+
30+
31+
@pytest.mark.asyncio
32+
@pytest.mark.xfail(
33+
sys.version_info < (3, 11), reason="requires asyncio Task context support"
34+
)
35+
async def test(context_var):
36+
assert context_var.get() == "value"

0 commit comments

Comments
 (0)