Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/fastmcp/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,14 @@ def __init__(
session: ServerSession | None = None,
*,
task_id: str | None = None,
origin_request_id: str | None = None,
):
self._fastmcp: weakref.ref[FastMCP] = weakref.ref(fastmcp)
self._session: ServerSession | None = session # For state ops during init
self._tokens: list[Token] = []
# Background task support (SEP-1686)
self._task_id: str | None = task_id
self._origin_request_id: str | None = origin_request_id
# Request-scoped state for non-serializable values (serializable=False)
self._request_state: dict[str, Any] = {}

Expand Down Expand Up @@ -227,6 +229,18 @@ def task_id(self) -> str | None:
"""
return self._task_id

@property
def origin_request_id(self) -> str | None:
"""Get the request ID that originated this execution, if available.

In foreground request mode, this is the current request_id.
In background task mode, this is the request_id captured when the task
was submitted, if one was available.
"""
if self.request_context is not None:
return str(self.request_context.request_id)
return self._origin_request_id

@property
def fastmcp(self) -> FastMCP:
"""Get the FastMCP instance."""
Expand Down Expand Up @@ -533,13 +547,14 @@ async def log(
extra: Optional mapping for additional arguments
"""
data = LogData(msg=message, extra=extra)
related_request_id = self.origin_request_id

await _log_to_server_and_client(
data=data,
session=self.session,
level=level or "info",
logger_name=logger_name,
related_request_id=self.request_id,
related_request_id=related_request_id,
)

@property
Expand Down
75 changes: 74 additions & 1 deletion src/fastmcp/server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,14 @@ def transform_context_annotations(fn: Callable[..., Any]) -> Callable[..., Any]:

# First pass: identify which params need transformation
params_to_transform: set[str] = set()
optional_context_params: set[str] = set()
for name, param in sig.parameters.items():
annotation = type_hints.get(name, param.annotation)
if is_class_member_of_type(annotation, Context):
if not isinstance(param.default, Dependency):
params_to_transform.add(name)
if param.default is None:
optional_context_params.add(name)

if not params_to_transform:
return fn
Expand All @@ -300,7 +303,10 @@ def transform_context_annotations(fn: Callable[..., Any]) -> Callable[..., Any]:
# We use CurrentContext() instead of Depends(get_context) because
# get_context() returns the Context which is an AsyncContextManager,
# and the DI system would try to enter it again (it's already entered)
param = param.replace(default=CurrentContext())
if name in optional_context_params:
param = param.replace(default=OptionalCurrentContext())
else:
param = param.replace(default=CurrentContext())

# Sort into buckets based on parameter kind
if param.kind == P.POSITIONAL_ONLY:
Expand Down Expand Up @@ -792,6 +798,36 @@ async def _restore_task_access_token(
return None


async def _restore_task_origin_request_id(session_id: str, task_id: str) -> str | None:
"""Restore the origin request ID snapshot for a background task.

Returns None if no request ID was captured at submission time.
"""
docket = _current_docket.get()
if docket is None:
return None

request_id_key = docket.key(
f"fastmcp:task:{session_id}:{task_id}:origin_request_id"
)
try:
async with docket.redis() as redis:
request_id_data = await redis.get(request_id_key)
if request_id_data is None:
return None
if isinstance(request_id_data, bytes):
return request_id_data.decode()
return str(request_id_data)
except Exception:
_logger.warning(
"Failed to restore origin request ID for task %s:%s",
session_id,
task_id,
exc_info=True,
)
return None


class _CurrentContext(Dependency): # type: ignore[misc]
"""Async context manager for Context dependency.

Expand All @@ -818,11 +854,15 @@ async def __aenter__(self) -> Context:
session = get_task_session(task_info.session_id)
# Get server from ContextVar
server = get_server()
origin_request_id = await _restore_task_origin_request_id(
task_info.session_id, task_info.task_id
)
# Create task-aware Context
self._context = Context(
fastmcp=server,
session=session,
task_id=task_info.task_id,
origin_request_id=origin_request_id,
)
# Enter the context to set up ContextVars
await self._context.__aenter__()
Expand Down Expand Up @@ -853,6 +893,34 @@ async def __aexit__(self, *args: object) -> None:
self._context = None


class _OptionalCurrentContext(Dependency): # type: ignore[misc]
"""Context dependency that degrades to None when no context is active.

This is implemented as a wrapper (composition), not a subclass of
`_CurrentContext`, to avoid overriding `__aenter__` with an incompatible
return type.
"""

_inner: _CurrentContext | None = None

async def __aenter__(self) -> Context | None:
inner = _CurrentContext()
try:
context = await inner.__aenter__()
except RuntimeError as exc:
if "No active context found" in str(exc):
return None
raise
self._inner = inner
return context

async def __aexit__(self, *args: object) -> None:
if self._inner is None:
return
await self._inner.__aexit__(*args)
self._inner = None


def CurrentContext() -> Context:
"""Get the current FastMCP Context instance.

Expand All @@ -878,6 +946,11 @@ async def log_progress(ctx: Context = CurrentContext()) -> str:
return cast("Context", _CurrentContext())


def OptionalCurrentContext() -> Context | None:
"""Get the current FastMCP Context, or None when no context is active."""
return cast("Context | None", _OptionalCurrentContext())


class _CurrentDocket(Dependency): # type: ignore[misc]
"""Async context manager for Docket dependency."""

Expand Down
8 changes: 8 additions & 0 deletions src/fastmcp/server/tasks/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ async def submit_to_docket(
poll_interval_key = docket.key(
f"fastmcp:task:{session_id}:{server_task_id}:poll_interval"
)
origin_request_id_key = docket.key(
f"fastmcp:task:{session_id}:{server_task_id}:origin_request_id"
)
poll_interval_ms = int(component.task_config.poll_interval.total_seconds() * 1000)
origin_request_id = (
str(ctx.request_context.request_id) if ctx.request_context is not None else None
Comment on lines +105 to +106
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve origin request ID when submitting from task context

submit_to_docket snapshots lineage from ctx.request_context.request_id, but in background workers request_context is always None, so no origin_request_id key is written for child submissions. This breaks request correlation for task chains (a task enqueueing another task), because downstream workers cannot recover the original request ID and logs/status notifications lose lineage even though Context.origin_request_id is available in that context.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I think if a task calls another task, it should provide all the necessary info to it. I don't think we want to magically propagate that.

)

# Snapshot the current access token (if any) for background task access (#3095)
access_token = get_access_token()
Expand All @@ -110,6 +116,8 @@ async def submit_to_docket(
await redis.set(task_meta_key, task_key, ex=ttl_seconds)
await redis.set(created_at_key, created_at.isoformat(), ex=ttl_seconds)
await redis.set(poll_interval_key, str(poll_interval_ms), ex=ttl_seconds)
if origin_request_id is not None:
await redis.set(origin_request_id_key, origin_request_id, ex=ttl_seconds)
if access_token is not None:
await redis.set(
access_token_key, access_token.model_dump_json(), ex=ttl_seconds
Expand Down
38 changes: 38 additions & 0 deletions tests/server/tasks/test_context_background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fastmcp import FastMCP
from fastmcp.client import Client
from fastmcp.client.elicitation import ElicitResult
from fastmcp.dependencies import CurrentDocket
from fastmcp.server.auth import AccessToken
from fastmcp.server.context import Context
from fastmcp.server.dependencies import get_access_token
Expand Down Expand Up @@ -229,6 +230,43 @@ async def verify_wiring(ctx: Context) -> str:
assert captured["session_id"] is not None
assert captured["is_background"] is True

async def test_origin_request_id_round_trips_through_background_task(self):
"""E2E: origin_request_id captured at submit time is restored in worker.

We validate this by comparing ctx.origin_request_id with the value
stored in Docket's Redis for this task.
"""

mcp = FastMCP("origin-request-id-roundtrip")

@mcp.tool(task=True)
async def check_origin_request_id(ctx: Context, docket=CurrentDocket()) -> str:
assert ctx.is_background_task is True
assert ctx.request_context is None
assert ctx.task_id is not None

origin = ctx.origin_request_id
assert origin is not None
assert isinstance(origin, str)
assert origin != ""

key = docket.key(
f"fastmcp:task:{ctx.session_id}:{ctx.task_id}:origin_request_id"
)
async with docket.redis() as redis:
raw = await redis.get(key)

assert raw is not None
if isinstance(raw, bytes):
raw = raw.decode()
assert str(raw) == origin
return "ok"

async with Client(mcp) as client:
task = await client.call_tool("check_origin_request_id", {}, task=True)
result = await task.result()
assert result.data == "ok"

async def test_elicit_accept_flow(self):
"""E2E: tool elicits input, client accepts via elicitation_handler."""
mcp = FastMCP("elicit-accept-test")
Expand Down
32 changes: 20 additions & 12 deletions tests/server/tasks/test_task_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,11 @@ async def return_file() -> File:
),
(
"return_image_data",
lambda r: len(r.content) == 1
and r.content[0].type == "image"
and r.content[0].mimeType == "image/png",
lambda r: (
len(r.content) == 1
and r.content[0].type == "image"
and r.content[0].mimeType == "image/png"
),
),
(
"return_audio",
Expand Down Expand Up @@ -615,25 +617,31 @@ async def return_mixed_content() -> list[TextContent | ImageContent]:
[
(
"return_text_content",
lambda r: len(r.content) == 1
and r.content[0].type == "text"
and r.content[0].text == "Direct text content",
lambda r: (
len(r.content) == 1
and r.content[0].type == "text"
and r.content[0].text == "Direct text content"
),
),
(
"return_image_content",
lambda r: len(r.content) == 1
and r.content[0].type == "image"
and r.content[0].mimeType == "image/png",
lambda r: (
len(r.content) == 1
and r.content[0].type == "image"
and r.content[0].mimeType == "image/png"
),
),
(
"return_embedded_resource",
lambda r: len(r.content) == 1 and r.content[0].type == "resource",
),
(
"return_resource_link",
lambda r: len(r.content) == 1
and r.content[0].type == "resource_link"
and str(r.content[0].uri) == "test://linked",
lambda r: (
len(r.content) == 1
and r.content[0].type == "resource_link"
and str(r.content[0].uri) == "test://linked"
),
),
],
)
Expand Down
36 changes: 36 additions & 0 deletions tests/server/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,42 @@ async def tool_with_validation(val: str = Depends(validate_input)) -> str:
class TestTransformContextAnnotations:
"""Tests for the transform_context_annotations function."""

async def test_optional_context_degrades_to_none_without_active_context(self):
"""Optional Context should resolve to None when no context is active."""
import inspect

from fastmcp.server.dependencies import transform_context_annotations

async def fn_with_optional_ctx(name: str, ctx: Context | None = None) -> str:
return name

transform_context_annotations(fn_with_optional_ctx)
sig = inspect.signature(fn_with_optional_ctx)
ctx_dependency = sig.parameters["ctx"].default

resolved_ctx = await ctx_dependency.__aenter__()
try:
assert resolved_ctx is None
finally:
await ctx_dependency.__aexit__(None, None, None)

async def test_optional_context_still_injected_in_foreground_requests(
self, mcp: FastMCP
):
"""Optional Context should still be injected for normal MCP requests."""

@mcp.tool()
async def tool_with_optional_context(
name: str, ctx: Context | None = None
) -> str:
if ctx is None:
return f"missing:{name}"
return f"present:{ctx.session_id}:{name}"

async with Client(mcp) as client:
result = await client.call_tool("tool_with_optional_context", {"name": "x"})
assert result.content[0].text.startswith("present:")

async def test_basic_context_transformation(self, mcp: FastMCP):
"""Test basic Context type annotation is transformed."""

Expand Down