diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index d65e87c436..9889aaee94 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -318,6 +318,10 @@ def lifespan_context(self) -> dict[str, Any]: Returns an empty dict if no lifespan was configured or if the MCP session is not yet established. + In background tasks (Docket workers), where request_context is not + available, falls back to reading from the FastMCP server's lifespan + result directly. + Example: ```python @server.tool @@ -330,6 +334,11 @@ def my_tool(ctx: Context) -> str: """ rc = self.request_context if rc is None: + # In background tasks, request_context is not available. + # Fall back to the server's lifespan result directly (#3095). + result = self.fastmcp._lifespan_result + if result is not None: + return result return {} return rc.lifespan_context diff --git a/src/fastmcp/server/dependencies.py b/src/fastmcp/server/dependencies.py index 97cf4fbeff..acd7fea194 100644 --- a/src/fastmcp/server/dependencies.py +++ b/src/fastmcp/server/dependencies.py @@ -9,11 +9,13 @@ import contextlib import inspect +import logging import weakref from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager -from contextvars import ContextVar +from contextvars import ContextVar, Token from dataclasses import dataclass +from datetime import datetime, timezone from functools import lru_cache from typing import TYPE_CHECKING, Any, Protocol, cast, get_type_hints, runtime_checkable @@ -33,6 +35,8 @@ from fastmcp.utilities.async_utils import call_sync_fn_in_threadpool from fastmcp.utilities.types import find_kwarg_by_type, is_class_member_of_type +_logger = logging.getLogger(__name__) + if TYPE_CHECKING: from docket import Docket from docket.worker import Worker @@ -166,6 +170,9 @@ def get_task_session(session_id: str) -> ServerSession | None: ) _current_docket: ContextVar[Docket | None] = ContextVar("docket", default=None) _current_worker: ContextVar[Worker | None] = ContextVar("worker", default=None) +_task_access_token: ContextVar[AccessToken | None] = ContextVar( + "task_access_token", default=None +) # --- Docket availability check --- @@ -479,7 +486,8 @@ def get_access_token() -> AccessToken | None: This function first tries to get the token from the current HTTP request's scope, which is more reliable for long-lived connections where the SDK's auth_context_var may become stale after token refresh. Falls back to the SDK's context var if no - request is available. + request is available. In background tasks (Docket workers), falls back to the + token snapshot stored in Redis at task submission time. Returns: The access token if an authenticated user is available, None otherwise. @@ -502,6 +510,19 @@ def get_access_token() -> AccessToken | None: if access_token is None: access_token = _sdk_get_access_token() + # Fall back to background task snapshot (#3095) + # In Docket workers, neither HTTP request nor SDK context var are available. + # The token was snapshotted in Redis at submit_to_docket() time and restored + # into this ContextVar by _CurrentContext.__aenter__(). + if access_token is None: + task_token = _task_access_token.get() + if task_token is not None: + # Check expiration: if expires_at is set and past, treat as expired + if task_token.expires_at is not None: + if task_token.expires_at < int(datetime.now(timezone.utc).timestamp()): + return None + return task_token + if access_token is None or isinstance(access_token, AccessToken): return access_token @@ -719,14 +740,49 @@ async def resolve_dependencies( # so that get_dependency_parameters can detect them. +async def _restore_task_access_token( + session_id: str, task_id: str +) -> Token[AccessToken | None] | None: + """Restore the access token snapshot from Redis into a ContextVar. + + Called when setting up context in a Docket worker. The token was stored at + submit_to_docket() time. The token is restored regardless of expiration; + get_access_token() checks expiry when reading from the ContextVar. + + Returns: + The ContextVar token for resetting, or None if nothing was restored. + """ + docket = _current_docket.get() + if docket is None: + return None + + token_key = docket.key(f"fastmcp:task:{session_id}:{task_id}:access_token") + try: + async with docket.redis() as redis: + token_data = await redis.get(token_key) + if token_data is not None: + restored = AccessToken.model_validate_json(token_data) + return _task_access_token.set(restored) + except Exception: + _logger.warning( + "Failed to restore access token for task %s:%s", + session_id, + task_id, + exc_info=True, + ) + return None + + class _CurrentContext(Dependency): # type: ignore[misc] """Async context manager for Context dependency. In foreground (request) mode: returns the active context from _current_context. - In background (Docket worker) mode: creates a task-aware Context with task_id. + In background (Docket worker) mode: creates a task-aware Context with task_id + and restores the access token snapshot from Redis. """ _context: Context | None = None + _access_token_cv_token: Token[AccessToken | None] | None = None async def __aenter__(self) -> Context: from fastmcp.server.context import Context, _current_context @@ -751,6 +807,12 @@ async def __aenter__(self) -> Context: ) # Enter the context to set up ContextVars await self._context.__aenter__() + + # Restore access token snapshot from Redis (#3095) + self._access_token_cv_token = await _restore_task_access_token( + task_info.session_id, task_info.task_id + ) + return self._context # Neither foreground nor background context available @@ -762,6 +824,10 @@ async def __aenter__(self) -> Context: ) async def __aexit__(self, *args: object) -> None: + # Clean up access token ContextVar + if self._access_token_cv_token is not None: + _task_access_token.reset(self._access_token_cv_token) + self._access_token_cv_token = None # Clean up if we created a context for background task if self._context is not None: await self._context.__aexit__(*args) @@ -1130,8 +1196,22 @@ async def __aexit__(self, *args: object) -> None: class _CurrentAccessToken(Dependency): # type: ignore[misc] """Async context manager for AccessToken dependency.""" + _access_token_cv_token: Token[AccessToken | None] | None = None + async def __aenter__(self) -> AccessToken: token = get_access_token() + + # If no token found and we're in a Docket worker, try restoring from + # Redis. This handles the case where ctx: Context is not in the + # function signature, so _CurrentContext never ran the restoration. + if token is None: + task_info = get_task_context() + if task_info is not None: + self._access_token_cv_token = await _restore_task_access_token( + task_info.session_id, task_info.task_id + ) + token = get_access_token() + if token is None: raise RuntimeError( "No access token found. Ensure authentication is configured " @@ -1140,7 +1220,9 @@ async def __aenter__(self) -> AccessToken: return token async def __aexit__(self, *args: object) -> None: - pass + if self._access_token_cv_token is not None: + _task_access_token.reset(self._access_token_cv_token) + self._access_token_cv_token = None def CurrentAccessToken() -> AccessToken: diff --git a/src/fastmcp/server/tasks/handlers.py b/src/fastmcp/server/tasks/handlers.py index fa8ba3ce4e..be7bddd615 100644 --- a/src/fastmcp/server/tasks/handlers.py +++ b/src/fastmcp/server/tasks/handlers.py @@ -14,7 +14,7 @@ from mcp.shared.exceptions import McpError from mcp.types import INTERNAL_ERROR, ErrorData -from fastmcp.server.dependencies import _current_docket, get_context +from fastmcp.server.dependencies import _current_docket, get_access_token, get_context from fastmcp.server.tasks.config import TaskMeta from fastmcp.server.tasks.keys import build_task_key from fastmcp.utilities.logging import get_logger @@ -99,10 +99,21 @@ async def submit_to_docket( f"fastmcp:task:{session_id}:{server_task_id}:poll_interval" ) poll_interval_ms = int(component.task_config.poll_interval.total_seconds() * 1000) + + # Snapshot the current access token (if any) for background task access (#3095) + access_token = get_access_token() + access_token_key = docket.key( + f"fastmcp:task:{session_id}:{server_task_id}:access_token" + ) + async with docket.redis() as redis: await redis.set(task_meta_key, task_key, ex=ttl_seconds) await redis.set(created_at_key, created_at.isoformat(), ex=ttl_seconds) await redis.set(poll_interval_key, str(poll_interval_ms), ex=ttl_seconds) + if access_token is not None: + await redis.set( + access_token_key, access_token.model_dump_json(), ex=ttl_seconds + ) # Register session for Context access in background workers (SEP-1686) # This enables elicitation/sampling from background tasks via weakref diff --git a/tests/server/tasks/test_context_background_task.py b/tests/server/tasks/test_context_background_task.py index 2b5a1efa95..c7eb9e90c2 100644 --- a/tests/server/tasks/test_context_background_task.py +++ b/tests/server/tasks/test_context_background_task.py @@ -14,7 +14,9 @@ from fastmcp import FastMCP from fastmcp.client import Client from fastmcp.client.elicitation import ElicitResult +from fastmcp.server.auth import AccessToken from fastmcp.server.context import Context +from fastmcp.server.dependencies import get_access_token from fastmcp.server.elicitation import AcceptedElicitation, DeclinedElicitation from fastmcp.server.tasks.elicitation import handle_task_input @@ -317,3 +319,124 @@ async def simple_tool() -> str: fastmcp=mcp, ) assert success is False + + +class TestAccessTokenInBackgroundTasks: + """Tests for access token availability in background tasks (#3095). + + Integration tests use Client(mcp) with the real memory:// Docket backend. + The token snapshot/restore round-trip flows through actual Redis (fakeredis). + + Note: async tests run in isolated asyncio tasks, so ContextVar changes + are automatically scoped — no cleanup required. + """ + + async def test_token_round_trips_through_background_task(self): + """E2E: token set at submit time is available inside the worker.""" + from mcp.server.auth.middleware.auth_context import auth_context_var + from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser + + mcp = FastMCP("token-roundtrip") + + @mcp.tool(task=True) + async def check_token(ctx: Context) -> str: + token = get_access_token() + if token is None: + return "no-token" + return f"{token.token}|{token.client_id}" + + test_token = AccessToken( + token="roundtrip-jwt", + client_id="test-client", + scopes=["read"], + claims={"sub": "user-1"}, + ) + auth_context_var.set(AuthenticatedUser(test_token)) + + async with Client(mcp) as client: + task = await client.call_tool("check_token", {}, task=True) + result = await task.result() + assert result.data == "roundtrip-jwt|test-client" + + async def test_no_token_when_unauthenticated(self): + """E2E: background task gets no token when nothing was set.""" + mcp = FastMCP("no-auth") + + @mcp.tool(task=True) + async def check_token(ctx: Context) -> str: + token = get_access_token() + return "no-token" if token is None else token.token + + async with Client(mcp) as client: + task = await client.call_tool("check_token", {}, task=True) + result = await task.result() + assert result.data == "no-token" + + async def test_expired_token_returns_none(self): + """get_access_token() returns None when task token has expired.""" + from datetime import datetime, timezone + + from fastmcp.server.dependencies import _task_access_token + + expired = AccessToken( + token="expired-jwt", + client_id="test-client", + scopes=["read"], + expires_at=int(datetime.now(timezone.utc).timestamp()) - 3600, + ) + _task_access_token.set(expired) + assert get_access_token() is None + + async def test_valid_token_with_future_expiry(self): + """get_access_token() returns token when expiry is in the future.""" + from datetime import datetime, timezone + + from fastmcp.server.dependencies import _task_access_token + + valid = AccessToken( + token="valid-jwt", + client_id="test-client", + scopes=["read"], + expires_at=int(datetime.now(timezone.utc).timestamp()) + 3600, + ) + _task_access_token.set(valid) + result = get_access_token() + assert result is not None + assert result.token == "valid-jwt" + + async def test_token_without_expiry_always_valid(self): + """get_access_token() returns token when no expires_at is set.""" + from fastmcp.server.dependencies import _task_access_token + + no_expiry = AccessToken( + token="eternal-jwt", + client_id="test-client", + scopes=["read"], + ) + _task_access_token.set(no_expiry) + result = get_access_token() + assert result is not None + assert result.token == "eternal-jwt" + + +class TestLifespanContextInBackgroundTasks: + """Tests for lifespan_context availability in background tasks (#3095).""" + + def test_lifespan_context_falls_back_to_server_result(self): + """lifespan_context reads from server when request_context is None.""" + mcp = FastMCP("test") + mcp._lifespan_result = {"db": "mock-db-connection", "cache": "mock-cache"} + + ctx = Context(mcp, task_id="test-task") + assert ctx.request_context is None + assert ctx.lifespan_context == { + "db": "mock-db-connection", + "cache": "mock-cache", + } + + def test_lifespan_context_returns_empty_dict_when_no_lifespan(self): + """lifespan_context returns {} when no lifespan is configured.""" + mcp = FastMCP("test") + ctx = Context(mcp, task_id="test-task") + assert ctx.request_context is None + assert ctx.lifespan_context == {}