Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 46 additions & 34 deletions src/docket/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,38 +697,50 @@ def __init__(self, parameter: str, error: Exception) -> None:
async def resolved_dependencies(
    worker: "Worker", execution: Execution
) -> AsyncGenerator[dict[str, Any], None]:
    """Resolve the dependency-injected arguments for one task execution.

    Sets the dependency contextvars for the duration of the context,
    resolves each declared dependency parameter (entering async context
    managers on a shared AsyncExitStack), and yields the resolved keyword
    arguments.  Every contextvar set here is reset on exit, so values never
    leak into the caller's context or into a subsequent execution.
    """
    # Capture the tokens for every contextvar we set so each one can be
    # reset in the mirrored finally blocks below.
    docket_token = Dependency.docket.set(worker.docket)
    worker_token = Dependency.worker.set(worker)
    execution_token = Dependency.execution.set(execution)
    cache_token = _Depends.cache.set({})

    try:
        async with AsyncExitStack() as stack:
            stack_token = _Depends.stack.set(stack)
            try:
                arguments: dict[str, Any] = {}

                parameters = get_dependency_parameters(execution.function)
                for parameter, dependency in parameters.items():
                    kwargs = execution.kwargs
                    # An explicitly supplied argument always wins over
                    # dependency resolution.
                    if parameter in kwargs:
                        arguments[parameter] = kwargs[parameter]
                        continue

                    # Special case for TaskArguments, they are "magical" and
                    # infer the parameter they refer to from the parameter
                    # name (unless otherwise specified).  At the top-level
                    # task function call, it doesn't make sense to specify
                    # one _without_ a parameter name, so we'll call that a
                    # failed dependency.
                    if (
                        isinstance(dependency, _TaskArgument)
                        and not dependency.parameter
                    ):
                        arguments[parameter] = FailedDependency(
                            parameter, ValueError("No parameter name specified")
                        )
                        continue

                    try:
                        arguments[parameter] = await stack.enter_async_context(
                            dependency
                        )
                    except Exception as error:
                        # A failing dependency doesn't abort resolution; it
                        # is surfaced to the caller as a FailedDependency.
                        arguments[parameter] = FailedDependency(parameter, error)

                yield arguments
            finally:
                # Reset the stack var before the AsyncExitStack unwinds.
                _Depends.stack.reset(stack_token)
    finally:
        # Reset in reverse order of setting.
        _Depends.cache.reset(cache_token)
        Dependency.execution.reset(execution_token)
        Dependency.worker.reset(worker_token)
        Dependency.docket.reset(docket_token)
216 changes: 214 additions & 2 deletions tests/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timedelta, timezone

import pytest

from docket import CurrentDocket, CurrentWorker, Docket, Worker
from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument
from docket.dependencies import (
Depends,
Dependency,
ExponentialRetry,
Retry,
TaskArgument,
_Depends, # type: ignore[attr-defined]
resolved_dependencies,
)
from docket.execution import Execution


async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker):
Expand Down Expand Up @@ -449,3 +458,206 @@ async def dependent_task(result: int = Depends(sync_adder)):
await worker.run_until_finished()

assert called


async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker):
    """Contextvars should be isolated between sequential task executions"""
    executions_seen: list[tuple[str, Execution]] = []

    async def first_task(a: str):
        # Capture the execution context during first task
        execution = Dependency.execution.get()
        executions_seen.append(("first", execution))
        assert a == "first"

    async def second_task(b: str):
        # Capture the execution context during second task
        execution = Dependency.execution.get()
        executions_seen.append(("second", execution))
        assert b == "second"

    await docket.add(first_task)(a="first")
    await docket.add(second_task)(b="second")
    await worker.run_until_finished()

    # Verify we captured both executions
    assert len(executions_seen) == 2

    # Find first and second executions (order may vary).  Note: named
    # `execution`, not `exec`, to avoid shadowing the builtin.
    executions_by_name = {name: execution for name, execution in executions_seen}
    assert set(executions_by_name.keys()) == {"first", "second"}

    # Verify the executions are different and have correct kwargs
    first_execution = executions_by_name["first"]
    second_execution = executions_by_name["second"]
    assert first_execution is not second_execution
    assert first_execution.kwargs["a"] == "first"
    assert second_execution.kwargs["b"] == "second"


async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker):
    """Contextvars should be reset after task execution completes"""
    captured: dict[str, object] = {}

    async def capture_task():
        # Record the per-task contextvar values while the task is running.
        captured["stack"] = _Depends.stack.get()
        captured["cache"] = _Depends.cache.get()

    await docket.add(capture_task)()
    await worker.run_until_finished()

    # Once the task has completed, every contextvar that
    # resolved_dependencies sets must be unset again, so reading any of
    # them raises LookupError.
    for contextvar in (
        _Depends.stack,
        _Depends.cache,
        Dependency.execution,
        Dependency.worker,
        Dependency.docket,
    ):
        with pytest.raises(LookupError):
            contextvar.get()


async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker):
    """Dependency cache should be fresh for each task, not reused"""
    call_counts = {"task1": 0, "task2": 0}

    def dependency_for_task1() -> str:
        call_counts["task1"] += 1
        return f"task1-call-{call_counts['task1']}"

    def dependency_for_task2() -> str:
        call_counts["task2"] += 1
        return f"task2-call-{call_counts['task2']}"

    async def first_task(val: str = Depends(dependency_for_task1)):
        assert val == "task1-call-1"

    async def second_task(val: str = Depends(dependency_for_task2)):
        assert val == "task2-call-1"

    # Execute each task in its own worker run, one after the other.
    for task in (first_task, second_task):
        await docket.add(task)()
        await worker.run_until_finished()

    # A fresh cache per task means each dependency resolved exactly once
    # (no cache leakage between tasks).
    assert call_counts == {"task1": 1, "task2": 1}


async def test_async_exit_stack_cleanup(docket: Docket, worker: Worker):
    """AsyncExitStack should be properly cleaned up after task execution"""
    teardown_log: list[str] = []

    @asynccontextmanager
    async def tracked_resource():
        # Yield the resource and record teardown when the context closes.
        try:
            yield "resource"
        finally:
            teardown_log.append("cleaned")

    async def task_with_context(res: str = Depends(tracked_resource)):
        assert res == "resource"
        # The context manager must still be open while the task runs.
        assert not teardown_log

    await docket.add(task_with_context)()
    await worker.run_until_finished()

    # Exiting the task closes the AsyncExitStack, and with it the resource.
    assert teardown_log == ["cleaned"]


async def test_contextvar_reset_on_reentrant_call(docket: Docket, worker: Worker):
    """Contextvars should be properly reset on reentrant calls to resolved_dependencies"""

    async def task1(): ...

    async def task2(): ...

    def build_execution(key, function) -> Execution:
        # All executions here are immediate, first-attempt, no-arg calls.
        return Execution(
            key=key,
            function=function,
            args=(),
            kwargs={},
            attempt=1,
            when=datetime.now(timezone.utc),
        )

    execution1 = build_execution("task1-key", task1)
    execution2 = build_execution("task2-key", task2)

    # First pass: record what the contextvars held inside the context.
    async with resolved_dependencies(worker, execution1):
        captured_exec1 = Dependency.execution.get()
        captured_stack1 = _Depends.stack.get()
        assert captured_exec1 is execution1

    # On exit, the execution contextvar must be unset again.
    with pytest.raises(LookupError):
        Dependency.execution.get()

    # Second pass: must see its own values, never leftovers from the first.
    async with resolved_dependencies(worker, execution2):
        captured_exec2 = Dependency.execution.get()
        captured_stack2 = _Depends.stack.get()
        assert captured_exec2 is execution2
        assert captured_exec2 is not captured_exec1
        # Each call gets a brand-new AsyncExitStack.
        assert captured_stack2 is not captured_stack1


async def test_contextvar_not_leaked_to_caller(docket: Docket):
    """Verify contextvars don't leak outside resolved_dependencies context"""
    # Nothing should be set before resolved_dependencies is ever entered.
    with pytest.raises(LookupError):
        Dependency.execution.get()

    async def dummy_task(): ...

    execution = Execution(
        key="test-key",
        function=dummy_task,
        args=(),
        kwargs={},
        attempt=1,
        when=datetime.now(timezone.utc),
    )

    async with (
        Docket("test-contextvar-leak", url="memory://leak-test") as test_docket,
        Worker(test_docket) as test_worker,
    ):
        async with resolved_dependencies(test_worker, execution):
            # Inside the context the execution is visible.
            assert Dependency.execution.get() is execution

    # After the context exits, every contextvar is unset again.
    with pytest.raises(LookupError):
        Dependency.execution.get()

    with pytest.raises(LookupError):
        _Depends.stack.get()

    with pytest.raises(LookupError):  # pragma: no branch
        _Depends.cache.get()