diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index a9245fa2b1..f73f2a3f1e 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -318,6 +318,10 @@ def lifespan_context(self) -> dict[str, Any]: Returns an empty dict if no lifespan was configured or if the MCP session is not yet established. + In background tasks (Docket workers), where request_context is not + available, falls back to reading from the FastMCP server's lifespan + result directly. + Example: ```python @server.tool @@ -330,6 +334,11 @@ def my_tool(ctx: Context) -> str: """ rc = self.request_context if rc is None: + # In background tasks, request_context is not available. + # Fall back to the server's lifespan result directly (#3095). + result = self.fastmcp._lifespan_result + if result is not None: + return result return {} return rc.lifespan_context diff --git a/src/fastmcp/server/dependencies.py b/src/fastmcp/server/dependencies.py index ffef2561bb..02aee2f223 100644 --- a/src/fastmcp/server/dependencies.py +++ b/src/fastmcp/server/dependencies.py @@ -12,7 +12,7 @@ import weakref from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager -from contextvars import ContextVar +from contextvars import ContextVar, Token from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, Protocol, cast, get_type_hints, runtime_checkable @@ -165,6 +165,9 @@ def get_task_session(session_id: str) -> ServerSession | None: ) _current_docket: ContextVar[Docket | None] = ContextVar("docket", default=None) _current_worker: ContextVar[Worker | None] = ContextVar("worker", default=None) +_task_access_token: ContextVar[AccessToken | None] = ContextVar( + "task_access_token", default=None +) # --- Docket availability check --- @@ -478,7 +481,8 @@ def get_access_token() -> AccessToken | None: This function first tries to get the token from the current HTTP request's scope, which is more reliable for long-lived connections where the SDK's auth_context_var may become stale after token refresh. Falls back to the SDK's context var if no - request is available. + request is available. In background tasks (Docket workers), falls back to the + token snapshot stored in Redis at task submission time. Returns: The access token if an authenticated user is available, None otherwise. @@ -501,6 +505,21 @@ def get_access_token() -> AccessToken | None: if access_token is None: access_token = _sdk_get_access_token() + # Fall back to background task snapshot (#3095) + # In Docket workers, neither HTTP request nor SDK context var are available. + # The token was snapshotted in Redis at submit_to_docket() time and restored + # into this ContextVar by _CurrentContext.__aenter__(). + if access_token is None: + task_token = _task_access_token.get() + if task_token is not None: + # Check expiration: if expires_at is set and past, treat as expired + if task_token.expires_at is not None: + from datetime import datetime, timezone + + if task_token.expires_at < int(datetime.now(timezone.utc).timestamp()): + return None + return task_token + if access_token is None or isinstance(access_token, AccessToken): return access_token @@ -718,14 +737,54 @@ async def resolve_dependencies( # so that get_dependency_parameters can detect them. +async def _restore_task_access_token( + session_id: str, task_id: str +) -> Token[AccessToken | None] | None: + """Restore the access token snapshot from Redis into a ContextVar. + + Called when setting up context in a Docket worker. The token was stored at + submit_to_docket() time. If the token has expired, it is not restored + (get_access_token() will return None). + + Returns: + The ContextVar token for resetting, or None if nothing was restored. + """ + docket = _current_docket.get() + # Fall back to docket's own worker ContextVar, which is set by the worker + # even when _CurrentContext hasn't run (no ctx: Context in signature) + if docket is None: + try: + from docket.dependencies import Dependency as DocketDependency + + docket = DocketDependency.docket.get() + except (ImportError, LookupError): + pass + if docket is None: + return None + + token_key = docket.key(f"fastmcp:task:{session_id}:{task_id}:access_token") + try: + async with docket.redis() as redis: + token_data = await redis.get(token_key) + if token_data is not None: + restored = AccessToken.model_validate_json(token_data) + return _task_access_token.set(restored) + except Exception: + # Don't let token restoration failures break task execution + pass + return None + + class _CurrentContext(Dependency): # type: ignore[misc] """Async context manager for Context dependency. In foreground (request) mode: returns the active context from _current_context. - In background (Docket worker) mode: creates a task-aware Context with task_id. + In background (Docket worker) mode: creates a task-aware Context with task_id + and restores the access token snapshot from Redis. """ _context: Context | None = None + _access_token_cv_token: Any = None async def __aenter__(self) -> Context: from fastmcp.server.context import Context, _current_context @@ -750,6 +809,12 @@ async def __aenter__(self) -> Context: ) # Enter the context to set up ContextVars await self._context.__aenter__() + + # Restore access token snapshot from Redis (#3095) + self._access_token_cv_token = await _restore_task_access_token( + task_info.session_id, task_info.task_id + ) + return self._context # Neither foreground nor background context available @@ -761,6 +826,10 @@ async def __aenter__(self) -> Context: ) async def __aexit__(self, *args: object) -> None: + # Clean up access token ContextVar + if self._access_token_cv_token is not None: + _task_access_token.reset(self._access_token_cv_token) + self._access_token_cv_token = None # Clean up if we created a context for background task if self._context is not None: await self._context.__aexit__(*args) @@ -994,8 +1063,22 @@ async def get_auth_type(headers: dict = CurrentHeaders()) -> str: class _CurrentAccessToken(Dependency): # type: ignore[misc] """Async context manager for AccessToken dependency.""" + _access_token_cv_token: Any = None + async def __aenter__(self) -> AccessToken: token = get_access_token() + + # If no token found and we're in a Docket worker, try restoring from + # Redis. This handles the case where ctx: Context is not in the + # function signature, so _CurrentContext never ran the restoration. + if token is None: + task_info = get_task_context() + if task_info is not None: + self._access_token_cv_token = await _restore_task_access_token( + task_info.session_id, task_info.task_id + ) + token = get_access_token() + if token is None: raise RuntimeError( "No access token found. Ensure authentication is configured " @@ -1004,7 +1087,9 @@ async def __aenter__(self) -> AccessToken: return token async def __aexit__(self, *args: object) -> None: - pass + if self._access_token_cv_token is not None: + _task_access_token.reset(self._access_token_cv_token) + self._access_token_cv_token = None def CurrentAccessToken() -> AccessToken: diff --git a/src/fastmcp/server/tasks/handlers.py b/src/fastmcp/server/tasks/handlers.py index 02da22148d..94d6be85c6 100644 --- a/src/fastmcp/server/tasks/handlers.py +++ b/src/fastmcp/server/tasks/handlers.py @@ -14,7 +14,7 @@ from mcp.shared.exceptions import McpError from mcp.types import INTERNAL_ERROR, ErrorData -from fastmcp.server.dependencies import _current_docket, get_context +from fastmcp.server.dependencies import _current_docket, get_access_token, get_context from fastmcp.server.tasks.config import TaskMeta from fastmcp.server.tasks.keys import build_task_key @@ -96,10 +96,21 @@ async def submit_to_docket( f"fastmcp:task:{session_id}:{server_task_id}:poll_interval" ) poll_interval_ms = int(component.task_config.poll_interval.total_seconds() * 1000) + + # Snapshot the current access token (if any) for background task access (#3095) + access_token = get_access_token() + access_token_key = docket.key( + f"fastmcp:task:{session_id}:{server_task_id}:access_token" + ) + async with docket.redis() as redis: await redis.set(task_meta_key, task_key, ex=ttl_seconds) await redis.set(created_at_key, created_at.isoformat(), ex=ttl_seconds) await redis.set(poll_interval_key, str(poll_interval_ms), ex=ttl_seconds) + if access_token is not None: + await redis.set( + access_token_key, access_token.model_dump_json(), ex=ttl_seconds + ) # Register session for Context access in background workers (SEP-1686) # This enables elicitation/sampling from background tasks via weakref diff --git a/tests/server/tasks/test_context_background_task.py b/tests/server/tasks/test_context_background_task.py index b778a63b29..fb12c10810 100644 --- a/tests/server/tasks/test_context_background_task.py +++ b/tests/server/tasks/test_context_background_task.py @@ -3,7 +3,9 @@ import pytest from fastmcp import FastMCP +from fastmcp.server.auth import AccessToken from fastmcp.server.context import Context +from fastmcp.server.dependencies import get_access_token from fastmcp.server.elicitation import AcceptedElicitation from fastmcp.server.tasks.elicitation import elicit_for_task, handle_task_input @@ -860,3 +862,332 @@ async def ask_for_name(ctx: Context) -> str: # Verify the tool received the elicited value and returned correctly assert result.data == "Hello, Bob!" + + +class TestAccessTokenInBackgroundTasks: + """Tests for access token availability in background tasks (#3095).""" + + async def test_access_token_stored_in_redis_at_submit_time(self): + """Verify submit_to_docket() snapshots the access token in Redis.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from fastmcp.server.tasks.handlers import submit_to_docket + + # Create a mock access token + token = AccessToken( + token="test-jwt-token-123", + client_id="test-client", + scopes=["read", "write"], + claims={"sub": "user-1"}, + ) + + # Track Redis set calls + redis_data: dict[str, str | bytes] = {} + + mock_redis = AsyncMock() + + async def mock_set(key, value, ex=None): + redis_data[key] = value + + mock_redis.set = mock_set + + mock_docket = MagicMock() + mock_docket.redis = MagicMock(return_value=AsyncMock()) + mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) + mock_docket.redis.return_value.__aexit__ = AsyncMock() + mock_docket.key = lambda k: k + mock_docket.execution_ttl.total_seconds.return_value = 300 + + # Mock context + mock_session = MagicMock() + mock_session._fastmcp_state_prefix = "test-session" + mock_session.send_notification = AsyncMock() + mock_session._subscription_task_group = None + + mock_ctx = MagicMock() + mock_ctx.session_id = "test-session" + mock_ctx.session = mock_session + + # Mock component + mock_component = MagicMock() + mock_component.task_config.poll_interval.total_seconds.return_value = 1.0 + mock_component.add_to_docket = AsyncMock() + + with ( + patch("fastmcp.server.tasks.handlers.get_context", return_value=mock_ctx), + patch( + "fastmcp.server.tasks.handlers._current_docket", + MagicMock(get=MagicMock(return_value=mock_docket)), + ), + patch( + "fastmcp.server.tasks.handlers.get_access_token", + return_value=token, + ), + ): + result = await submit_to_docket( + task_type="tool", + key="test_tool", + component=mock_component, + arguments={"x": 1}, + ) + + # Verify token was stored in Redis + task_id = result.task.taskId + token_key = f"fastmcp:task:test-session:{task_id}:access_token" + assert token_key in redis_data + + # Verify stored token can be deserialized + restored = AccessToken.model_validate_json(redis_data[token_key]) + assert restored.token == "test-jwt-token-123" + assert restored.client_id == "test-client" + assert restored.scopes == ["read", "write"] + assert restored.claims == {"sub": "user-1"} + + async def test_access_token_not_stored_when_unauthenticated(self): + """Verify submit_to_docket() doesn't store token when no auth.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from fastmcp.server.tasks.handlers import submit_to_docket + + redis_data: dict[str, str | bytes] = {} + + mock_redis = AsyncMock() + + async def mock_set(key, value, ex=None): + redis_data[key] = value + + mock_redis.set = mock_set + + mock_docket = MagicMock() + mock_docket.redis = MagicMock(return_value=AsyncMock()) + mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) + mock_docket.redis.return_value.__aexit__ = AsyncMock() + mock_docket.key = lambda k: k + mock_docket.execution_ttl.total_seconds.return_value = 300 + + mock_session = MagicMock() + mock_session._fastmcp_state_prefix = "test-session" + mock_session.send_notification = AsyncMock() + mock_session._subscription_task_group = None + + mock_ctx = MagicMock() + mock_ctx.session_id = "test-session" + mock_ctx.session = mock_session + + mock_component = MagicMock() + mock_component.task_config.poll_interval.total_seconds.return_value = 1.0 + mock_component.add_to_docket = AsyncMock() + + with ( + patch("fastmcp.server.tasks.handlers.get_context", return_value=mock_ctx), + patch( + "fastmcp.server.tasks.handlers._current_docket", + MagicMock(get=MagicMock(return_value=mock_docket)), + ), + patch( + "fastmcp.server.tasks.handlers.get_access_token", + return_value=None, + ), + ): + result = await submit_to_docket( + task_type="tool", + key="test_tool", + component=mock_component, + arguments={"x": 1}, + ) + + # Verify no token key was stored + task_id = result.task.taskId + token_key = f"fastmcp:task:test-session:{task_id}:access_token" + assert token_key not in redis_data + + async def test_access_token_restored_in_background_task_context(self): + """Verify _CurrentContext restores access token from Redis in workers.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from fastmcp.server.dependencies import ( + TaskContextInfo, + _current_docket, + _current_server, + _CurrentContext, + _task_access_token, + _task_sessions, + ) + + # Create token to store + token = AccessToken( + token="bg-task-token", + client_id="bg-client", + scopes=["admin"], + claims={"sub": "admin-user"}, + ) + + # Set up mock Redis with pre-stored token + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=token.model_dump_json().encode()) + + mock_docket = MagicMock() + mock_docket.redis = MagicMock(return_value=AsyncMock()) + mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) + mock_docket.redis.return_value.__aexit__ = AsyncMock() + mock_docket.key = lambda k: k + + # Set up server and session + mock_server = MagicMock() + mock_server._docket = mock_docket + server_token = _current_server.set(MagicMock(return_value=mock_server)) + docket_token = _current_docket.set(mock_docket) + + mock_session = MagicMock() + mock_session._fastmcp_state_prefix = "test-session-id" + _task_sessions["test-session-id"] = MagicMock(return_value=mock_session) + + try: + task_info = TaskContextInfo( + task_id="test-task-123", + session_id="test-session-id", + ) + with patch( + "fastmcp.server.dependencies.get_task_context", + return_value=task_info, + ): + dep = _CurrentContext() + ctx = await dep.__aenter__() + + # Verify context is task-aware + assert ctx.is_background_task is True + + # Verify access token was restored into ContextVar + restored = _task_access_token.get() + assert restored is not None + assert restored.token == "bg-task-token" + assert restored.client_id == "bg-client" + assert restored.claims == {"sub": "admin-user"} + + # Verify get_access_token() returns the restored token + result = get_access_token() + assert result is not None + assert result.token == "bg-task-token" + + # Clean up + await dep.__aexit__(None, None, None) + + # Verify ContextVar was reset after exit + assert _task_access_token.get() is None + finally: + _current_server.reset(server_token) + _current_docket.reset(docket_token) + _task_sessions.pop("test-session-id", None) + + async def test_expired_access_token_returns_none(self): + """Verify expired tokens return None from get_access_token().""" + from datetime import datetime, timezone + + from fastmcp.server.dependencies import _task_access_token + + # Create an expired token (expired 1 hour ago) + expired_token = AccessToken( + token="expired-token", + client_id="test-client", + scopes=["read"], + expires_at=int(datetime.now(timezone.utc).timestamp()) - 3600, + ) + + token = _task_access_token.set(expired_token) + try: + result = get_access_token() + assert result is None + finally: + _task_access_token.reset(token) + + async def test_valid_access_token_with_future_expiry(self): + """Verify non-expired tokens are returned from get_access_token().""" + from datetime import datetime, timezone + + from fastmcp.server.dependencies import _task_access_token + + # Create a valid token (expires in 1 hour) + valid_token = AccessToken( + token="valid-token", + client_id="test-client", + scopes=["read"], + expires_at=int(datetime.now(timezone.utc).timestamp()) + 3600, + ) + + token = _task_access_token.set(valid_token) + try: + result = get_access_token() + assert result is not None + assert result.token == "valid-token" + finally: + _task_access_token.reset(token) + + async def test_access_token_without_expiry_returned(self): + """Verify tokens without expires_at are returned (no expiry check).""" + from fastmcp.server.dependencies import _task_access_token + + token_no_expiry = AccessToken( + token="no-expiry-token", + client_id="test-client", + scopes=["read"], + ) + + token = _task_access_token.set(token_no_expiry) + try: + result = get_access_token() + assert result is not None + assert result.token == "no-expiry-token" + finally: + _task_access_token.reset(token) + + +class TestLifespanContextInBackgroundTasks: + """Tests for lifespan_context availability in background tasks (#3095).""" + + def test_lifespan_context_falls_back_to_server_result(self): + """Verify lifespan_context reads from server when request_context is None.""" + mcp = FastMCP("test") + # Simulate lifespan result being set (as would happen during server startup) + mcp._lifespan_result = {"db": "mock-db-connection", "cache": "mock-cache"} + mcp._lifespan_result_set = True + + # Create context without request_context (background task scenario) + ctx = Context(mcp, task_id="test-task") + + # request_context should be None (no MCP session) + assert ctx.request_context is None + + # lifespan_context should fall back to server's lifespan result + assert ctx.lifespan_context == { + "db": "mock-db-connection", + "cache": "mock-cache", + } + + def test_lifespan_context_returns_empty_dict_when_no_lifespan(self): + """Verify lifespan_context returns {} when no lifespan configured.""" + mcp = FastMCP("test") + + ctx = Context(mcp, task_id="test-task") + assert ctx.request_context is None + assert ctx.lifespan_context == {} + + def test_lifespan_context_still_uses_request_context_when_available(self): + """Verify lifespan_context prefers request_context when available.""" + from unittest.mock import MagicMock, patch + + mcp = FastMCP("test") + mcp._lifespan_result = {"server": "value"} + mcp._lifespan_result_set = True + + ctx = Context(mcp) + + # Mock request_context with different lifespan data + mock_rc = MagicMock() + mock_rc.lifespan_context = {"request": "value"} + + with patch.object( + type(ctx), + "request_context", + new_callable=lambda: property(lambda self: mock_rc), + ): + assert ctx.lifespan_context == {"request": "value"}