diff --git a/src/docket/dependencies.py b/src/docket/dependencies.py index 704ed42f..40daddd4 100644 --- a/src/docket/dependencies.py +++ b/src/docket/dependencies.py @@ -697,38 +697,50 @@ def __init__(self, parameter: str, error: Exception) -> None: async def resolved_dependencies( worker: "Worker", execution: Execution ) -> AsyncGenerator[dict[str, Any], None]: - # Set context variables once at the beginning - Dependency.docket.set(worker.docket) - Dependency.worker.set(worker) - Dependency.execution.set(execution) - - _Depends.cache.set({}) - - async with AsyncExitStack() as stack: - _Depends.stack.set(stack) - - arguments: dict[str, Any] = {} - - parameters = get_dependency_parameters(execution.function) - for parameter, dependency in parameters.items(): - kwargs = execution.kwargs - if parameter in kwargs: - arguments[parameter] = kwargs[parameter] - continue - - # Special case for TaskArguments, they are "magical" and infer the parameter - # they refer to from the parameter name (unless otherwise specified). At - # the top-level task function call, it doesn't make sense to specify one - # _without_ a parameter name, so we'll call that a failed dependency. - if isinstance(dependency, _TaskArgument) and not dependency.parameter: - arguments[parameter] = FailedDependency( - parameter, ValueError("No parameter name specified") - ) - continue - + # Capture tokens for all contextvar sets to ensure proper cleanup + docket_token = Dependency.docket.set(worker.docket) + worker_token = Dependency.worker.set(worker) + execution_token = Dependency.execution.set(execution) + cache_token = _Depends.cache.set({}) + + try: + async with AsyncExitStack() as stack: + stack_token = _Depends.stack.set(stack) try: - arguments[parameter] = await stack.enter_async_context(dependency) - except Exception as error: - arguments[parameter] = FailedDependency(parameter, error) - - yield arguments + arguments: dict[str, Any] = {} + + parameters = get_dependency_parameters(execution.function) + for parameter, dependency in parameters.items(): + kwargs = execution.kwargs + if parameter in kwargs: + arguments[parameter] = kwargs[parameter] + continue + + # Special case for TaskArguments, they are "magical" and infer the parameter + # they refer to from the parameter name (unless otherwise specified). At + # the top-level task function call, it doesn't make sense to specify one + # _without_ a parameter name, so we'll call that a failed dependency. + if ( + isinstance(dependency, _TaskArgument) + and not dependency.parameter + ): + arguments[parameter] = FailedDependency( + parameter, ValueError("No parameter name specified") + ) + continue + + try: + arguments[parameter] = await stack.enter_async_context( + dependency + ) + except Exception as error: + arguments[parameter] = FailedDependency(parameter, error) + + yield arguments + finally: + _Depends.stack.reset(stack_token) + finally: + _Depends.cache.reset(cache_token) + Dependency.execution.reset(execution_token) + Dependency.worker.reset(worker_token) + Dependency.docket.reset(docket_token) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index e802e93c..7d0c201c 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,11 +1,20 @@ import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timedelta, timezone import pytest from docket import CurrentDocket, CurrentWorker, Docket, Worker -from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument +from docket.dependencies import ( + Depends, + Dependency, + ExponentialRetry, + Retry, + TaskArgument, + _Depends, # type: ignore[attr-defined] + resolved_dependencies, +) +from docket.execution import Execution async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker): @@ -449,3 +458,206 @@ async def dependent_task(result: int = Depends(sync_adder)): await worker.run_until_finished() assert called + + +async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker): + """Contextvars should be isolated between sequential task executions""" + executions_seen: list[tuple[str, Execution]] = [] + + async def first_task(a: str): + # Capture the execution context during first task + execution = Dependency.execution.get() + executions_seen.append(("first", execution)) + assert a == "first" + + async def second_task(b: str): + # Capture the execution context during second task + execution = Dependency.execution.get() + executions_seen.append(("second", execution)) + assert b == "second" + + await docket.add(first_task)(a="first") + await docket.add(second_task)(b="second") + await worker.run_until_finished() + + # Verify we captured both executions + assert len(executions_seen) == 2 + + # Find first and second executions (order may vary) + executions_by_name = {name: exec for name, exec in executions_seen} + assert set(executions_by_name.keys()) == {"first", "second"} + + # Verify the executions are different and have correct kwargs + first_execution = executions_by_name["first"] + second_execution = executions_by_name["second"] + assert first_execution is not second_execution + assert first_execution.kwargs["a"] == "first" + assert second_execution.kwargs["b"] == "second" + + +async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker): + """Contextvars should be reset after task execution completes""" + captured_stack = None + captured_cache = None + + async def capture_task(): + nonlocal captured_stack, captured_cache + # Capture references during task execution + captured_stack = _Depends.stack.get() + captured_cache = _Depends.cache.get() + + await docket.add(capture_task)() + await worker.run_until_finished() + + # After the task completes, the contextvars should be reset + # Attempting to get them should raise LookupError + with pytest.raises(LookupError): + _Depends.stack.get() + + with pytest.raises(LookupError): + _Depends.cache.get() + + with pytest.raises(LookupError): + Dependency.execution.get() + + with pytest.raises(LookupError): + Dependency.worker.get() + + with pytest.raises(LookupError): + Dependency.docket.get() + + +async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker): + """Dependency cache should be fresh for each task, not reused""" + call_counts = {"task1": 0, "task2": 0} + + def dependency_for_task1() -> str: + call_counts["task1"] += 1 + return f"task1-call-{call_counts['task1']}" + + def dependency_for_task2() -> str: + call_counts["task2"] += 1 + return f"task2-call-{call_counts['task2']}" + + async def first_task(val: str = Depends(dependency_for_task1)): + assert val == "task1-call-1" + + async def second_task(val: str = Depends(dependency_for_task2)): + assert val == "task2-call-1" + + # Run tasks sequentially + await docket.add(first_task)() + await worker.run_until_finished() + + await docket.add(second_task)() + await worker.run_until_finished() + + # Each dependency should have been called once (no cache leakage between tasks) + assert call_counts["task1"] == 1 + assert call_counts["task2"] == 1 + + +async def test_async_exit_stack_cleanup(docket: Docket, worker: Worker): + """AsyncExitStack should be properly cleaned up after task execution""" + cleanup_called: list[str] = [] + + @asynccontextmanager + async def tracked_resource(): + try: + yield "resource" + finally: + cleanup_called.append("cleaned") + + async def task_with_context(res: str = Depends(tracked_resource)): + assert res == "resource" + assert len(cleanup_called) == 0 # Not cleaned up yet + + await docket.add(task_with_context)() + await worker.run_until_finished() + + # After task completes, cleanup should have been called + assert cleanup_called == ["cleaned"] + + +async def test_contextvar_reset_on_reentrant_call(docket: Docket, worker: Worker): + """Contextvars should be properly reset on reentrant calls to resolved_dependencies""" + + # Create two mock executions + async def task1(): ... + + async def task2(): ... + + execution1 = Execution( + key="task1-key", + function=task1, + args=(), + kwargs={}, + attempt=1, + when=datetime.now(timezone.utc), + ) + + execution2 = Execution( + key="task2-key", + function=task2, + args=(), + kwargs={}, + attempt=1, + when=datetime.now(timezone.utc), + ) + + # Capture contextvars from first call + captured_exec1 = None + captured_stack1 = None + + async with resolved_dependencies(worker, execution1): + captured_exec1 = Dependency.execution.get() + captured_stack1 = _Depends.stack.get() + assert captured_exec1 is execution1 + + # After exiting, contextvars should be reset (raise LookupError) + with pytest.raises(LookupError): + Dependency.execution.get() + + # Now make a second call - should not see values from first call + async with resolved_dependencies(worker, execution2): + captured_exec2 = Dependency.execution.get() + captured_stack2 = _Depends.stack.get() + assert captured_exec2 is execution2 + assert captured_exec2 is not captured_exec1 + # Stacks should be different objects + assert captured_stack2 is not captured_stack1 + + +async def test_contextvar_not_leaked_to_caller(docket: Docket): + """Verify contextvars don't leak outside resolved_dependencies context""" + # Before calling resolved_dependencies, contextvars should not be set + with pytest.raises(LookupError): + Dependency.execution.get() + + async def dummy_task(): ... + + execution = Execution( + key="test-key", + function=dummy_task, + args=(), + kwargs={}, + attempt=1, + when=datetime.now(timezone.utc), + ) + + async with Docket("test-contextvar-leak", url="memory://leak-test") as test_docket: + async with Worker(test_docket) as test_worker: + # Use resolved_dependencies + async with resolved_dependencies(test_worker, execution): + # Inside context, we should be able to get values + assert Dependency.execution.get() is execution + + # After exiting context, contextvars should be cleaned up + with pytest.raises(LookupError): + Dependency.execution.get() + + with pytest.raises(LookupError): + _Depends.stack.get() + + with pytest.raises(LookupError): # pragma: no branch + _Depends.cache.get()