diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index 4cf8b7050f..afe8966f56 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -191,12 +191,14 @@ def __init__( session: ServerSession | None = None, *, task_id: str | None = None, + origin_request_id: str | None = None, ): self._fastmcp: weakref.ref[FastMCP] = weakref.ref(fastmcp) self._session: ServerSession | None = session # For state ops during init self._tokens: list[Token] = [] # Background task support (SEP-1686) self._task_id: str | None = task_id + self._origin_request_id: str | None = origin_request_id # Request-scoped state for non-serializable values (serializable=False) self._request_state: dict[str, Any] = {} @@ -227,6 +229,18 @@ def task_id(self) -> str | None: """ return self._task_id + @property + def origin_request_id(self) -> str | None: + """Get the request ID that originated this execution, if available. + + In foreground request mode, this is the current request_id. + In background task mode, this is the request_id captured when the task + was submitted, if one was available. + """ + if self.request_context is not None: + return str(self.request_context.request_id) + return self._origin_request_id + @property def fastmcp(self) -> FastMCP: """Get the FastMCP instance.""" @@ -533,13 +547,14 @@ async def log( extra: Optional mapping for additional arguments """ data = LogData(msg=message, extra=extra) + related_request_id = self.origin_request_id await _log_to_server_and_client( data=data, session=self.session, level=level or "info", logger_name=logger_name, - related_request_id=self.request_id, + related_request_id=related_request_id, ) @property diff --git a/src/fastmcp/server/dependencies.py b/src/fastmcp/server/dependencies.py index 3e72d9f84b..ccee89adcb 100644 --- a/src/fastmcp/server/dependencies.py +++ b/src/fastmcp/server/dependencies.py @@ -270,11 +270,14 @@ def transform_context_annotations(fn: Callable[..., Any]) -> Callable[..., Any]: # First pass: identify which params need transformation params_to_transform: set[str] = set() + optional_context_params: set[str] = set() for name, param in sig.parameters.items(): annotation = type_hints.get(name, param.annotation) if is_class_member_of_type(annotation, Context): if not isinstance(param.default, Dependency): params_to_transform.add(name) + if param.default is None: + optional_context_params.add(name) if not params_to_transform: return fn @@ -300,7 +303,10 @@ def transform_context_annotations(fn: Callable[..., Any]) -> Callable[..., Any]: # We use CurrentContext() instead of Depends(get_context) because # get_context() returns the Context which is an AsyncContextManager, # and the DI system would try to enter it again (it's already entered) - param = param.replace(default=CurrentContext()) + if name in optional_context_params: + param = param.replace(default=OptionalCurrentContext()) + else: + param = param.replace(default=CurrentContext()) # Sort into buckets based on parameter kind if param.kind == P.POSITIONAL_ONLY: @@ -792,6 +798,36 @@ async def _restore_task_access_token( return None +async def _restore_task_origin_request_id(session_id: str, task_id: str) -> str | None: + """Restore the origin request ID snapshot for a background task. + + Returns None if no request ID was captured at submission time. + """ + docket = _current_docket.get() + if docket is None: + return None + + request_id_key = docket.key( + f"fastmcp:task:{session_id}:{task_id}:origin_request_id" + ) + try: + async with docket.redis() as redis: + request_id_data = await redis.get(request_id_key) + if request_id_data is None: + return None + if isinstance(request_id_data, bytes): + return request_id_data.decode() + return str(request_id_data) + except Exception: + _logger.warning( + "Failed to restore origin request ID for task %s:%s", + session_id, + task_id, + exc_info=True, + ) + return None + + class _CurrentContext(Dependency): # type: ignore[misc] """Async context manager for Context dependency. @@ -818,11 +854,15 @@ async def __aenter__(self) -> Context: session = get_task_session(task_info.session_id) # Get server from ContextVar server = get_server() + origin_request_id = await _restore_task_origin_request_id( + task_info.session_id, task_info.task_id + ) # Create task-aware Context self._context = Context( fastmcp=server, session=session, task_id=task_info.task_id, + origin_request_id=origin_request_id, ) # Enter the context to set up ContextVars await self._context.__aenter__() @@ -853,6 +893,34 @@ async def __aexit__(self, *args: object) -> None: self._context = None +class _OptionalCurrentContext(Dependency): # type: ignore[misc] + """Context dependency that degrades to None when no context is active. + + This is implemented as a wrapper (composition), not a subclass of + `_CurrentContext`, to avoid overriding `__aenter__` with an incompatible + return type. + """ + + _inner: _CurrentContext | None = None + + async def __aenter__(self) -> Context | None: + inner = _CurrentContext() + try: + context = await inner.__aenter__() + except RuntimeError as exc: + if "No active context found" in str(exc): + return None + raise + self._inner = inner + return context + + async def __aexit__(self, *args: object) -> None: + if self._inner is None: + return + await self._inner.__aexit__(*args) + self._inner = None + + def CurrentContext() -> Context: """Get the current FastMCP Context instance. @@ -878,6 +946,11 @@ async def log_progress(ctx: Context = CurrentContext()) -> str: return cast("Context", _CurrentContext()) +def OptionalCurrentContext() -> Context | None: + """Get the current FastMCP Context, or None when no context is active.""" + return cast("Context | None", _OptionalCurrentContext()) + + class _CurrentDocket(Dependency): # type: ignore[misc] """Async context manager for Docket dependency.""" diff --git a/src/fastmcp/server/tasks/handlers.py b/src/fastmcp/server/tasks/handlers.py index be7bddd615..10bf18b6d0 100644 --- a/src/fastmcp/server/tasks/handlers.py +++ b/src/fastmcp/server/tasks/handlers.py @@ -98,7 +98,13 @@ async def submit_to_docket( poll_interval_key = docket.key( f"fastmcp:task:{session_id}:{server_task_id}:poll_interval" ) + origin_request_id_key = docket.key( + f"fastmcp:task:{session_id}:{server_task_id}:origin_request_id" + ) poll_interval_ms = int(component.task_config.poll_interval.total_seconds() * 1000) + origin_request_id = ( + str(ctx.request_context.request_id) if ctx.request_context is not None else None + ) # Snapshot the current access token (if any) for background task access (#3095) access_token = get_access_token() @@ -110,6 +116,8 @@ async def submit_to_docket( await redis.set(task_meta_key, task_key, ex=ttl_seconds) await redis.set(created_at_key, created_at.isoformat(), ex=ttl_seconds) await redis.set(poll_interval_key, str(poll_interval_ms), ex=ttl_seconds) + if origin_request_id is not None: + await redis.set(origin_request_id_key, origin_request_id, ex=ttl_seconds) if access_token is not None: await redis.set( access_token_key, access_token.model_dump_json(), ex=ttl_seconds diff --git a/tests/server/tasks/test_context_background_task.py b/tests/server/tasks/test_context_background_task.py index c7eb9e90c2..8b2a92889a 100644 --- a/tests/server/tasks/test_context_background_task.py +++ b/tests/server/tasks/test_context_background_task.py @@ -14,6 +14,7 @@ from fastmcp import FastMCP from fastmcp.client import Client from fastmcp.client.elicitation import ElicitResult +from fastmcp.dependencies import CurrentDocket from fastmcp.server.auth import AccessToken from fastmcp.server.context import Context from fastmcp.server.dependencies import get_access_token @@ -229,6 +230,43 @@ async def verify_wiring(ctx: Context) -> str: assert captured["session_id"] is not None assert captured["is_background"] is True + async def test_origin_request_id_round_trips_through_background_task(self): + """E2E: origin_request_id captured at submit time is restored in worker. + + We validate this by comparing ctx.origin_request_id with the value + stored in Docket's Redis for this task. + """ + + mcp = FastMCP("origin-request-id-roundtrip") + + @mcp.tool(task=True) + async def check_origin_request_id(ctx: Context, docket=CurrentDocket()) -> str: + assert ctx.is_background_task is True + assert ctx.request_context is None + assert ctx.task_id is not None + + origin = ctx.origin_request_id + assert origin is not None + assert isinstance(origin, str) + assert origin != "" + + key = docket.key( + f"fastmcp:task:{ctx.session_id}:{ctx.task_id}:origin_request_id" + ) + async with docket.redis() as redis: + raw = await redis.get(key) + + assert raw is not None + if isinstance(raw, bytes): + raw = raw.decode() + assert str(raw) == origin + return "ok" + + async with Client(mcp) as client: + task = await client.call_tool("check_origin_request_id", {}, task=True) + result = await task.result() + assert result.data == "ok" + async def test_elicit_accept_flow(self): """E2E: tool elicits input, client accepts via elicitation_handler.""" mcp = FastMCP("elicit-accept-test") diff --git a/tests/server/tasks/test_task_return_types.py b/tests/server/tasks/test_task_return_types.py index cbceac4b49..3ef3529029 100644 --- a/tests/server/tasks/test_task_return_types.py +++ b/tests/server/tasks/test_task_return_types.py @@ -402,9 +402,11 @@ async def return_file() -> File: ), ( "return_image_data", - lambda r: len(r.content) == 1 - and r.content[0].type == "image" - and r.content[0].mimeType == "image/png", + lambda r: ( + len(r.content) == 1 + and r.content[0].type == "image" + and r.content[0].mimeType == "image/png" + ), ), ( "return_audio", @@ -615,15 +617,19 @@ async def return_mixed_content() -> list[TextContent | ImageContent]: [ ( "return_text_content", - lambda r: len(r.content) == 1 - and r.content[0].type == "text" - and r.content[0].text == "Direct text content", + lambda r: ( + len(r.content) == 1 + and r.content[0].type == "text" + and r.content[0].text == "Direct text content" + ), ), ( "return_image_content", - lambda r: len(r.content) == 1 - and r.content[0].type == "image" - and r.content[0].mimeType == "image/png", + lambda r: ( + len(r.content) == 1 + and r.content[0].type == "image" + and r.content[0].mimeType == "image/png" + ), ), ( "return_embedded_resource", @@ -631,9 +637,11 @@ async def return_mixed_content() -> list[TextContent | ImageContent]: ), ( "return_resource_link", - lambda r: len(r.content) == 1 - and r.content[0].type == "resource_link" - and str(r.content[0].uri) == "test://linked", + lambda r: ( + len(r.content) == 1 + and r.content[0].type == "resource_link" + and str(r.content[0].uri) == "test://linked" + ), ), ], ) diff --git a/tests/server/test_dependencies.py b/tests/server/test_dependencies.py index 106babecce..df7f046977 100644 --- a/tests/server/test_dependencies.py +++ b/tests/server/test_dependencies.py @@ -754,6 +754,42 @@ async def tool_with_validation(val: str = Depends(validate_input)) -> str: class TestTransformContextAnnotations: """Tests for the transform_context_annotations function.""" + async def test_optional_context_degrades_to_none_without_active_context(self): + """Optional Context should resolve to None when no context is active.""" + import inspect + + from fastmcp.server.dependencies import transform_context_annotations + + async def fn_with_optional_ctx(name: str, ctx: Context | None = None) -> str: + return name + + transform_context_annotations(fn_with_optional_ctx) + sig = inspect.signature(fn_with_optional_ctx) + ctx_dependency = sig.parameters["ctx"].default + + resolved_ctx = await ctx_dependency.__aenter__() + try: + assert resolved_ctx is None + finally: + await ctx_dependency.__aexit__(None, None, None) + + async def test_optional_context_still_injected_in_foreground_requests( + self, mcp: FastMCP + ): + """Optional Context should still be injected for normal MCP requests.""" + + @mcp.tool() + async def tool_with_optional_context( + name: str, ctx: Context | None = None + ) -> str: + if ctx is None: + return f"missing:{name}" + return f"present:{ctx.session_id}:{name}" + + async with Client(mcp) as client: + result = await client.call_tool("tool_with_optional_context", {"name": "x"}) + assert result.content[0].text.startswith("present:") + async def test_basic_context_transformation(self, mcp: FastMCP): """Test basic Context type annotation is transformed."""