Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/fastmcp/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ def lifespan_context(self) -> dict[str, Any]:
Returns an empty dict if no lifespan was configured or if the MCP
session is not yet established.

In background tasks (Docket workers), where request_context is not
available, falls back to reading from the FastMCP server's lifespan
result directly.

Example:
```python
@server.tool
Expand All @@ -330,6 +334,11 @@ def my_tool(ctx: Context) -> str:
"""
rc = self.request_context
if rc is None:
# In background tasks, request_context is not available.
# Fall back to the server's lifespan result directly (#3095).
result = self.fastmcp._lifespan_result
if result is not None:
return result
return {}
return rc.lifespan_context

Expand Down
90 changes: 86 additions & 4 deletions src/fastmcp/server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import contextlib
import inspect
import logging
import weakref
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from contextvars import ContextVar
from contextvars import ContextVar, Token
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Protocol, cast, get_type_hints, runtime_checkable

Expand All @@ -33,6 +35,8 @@
from fastmcp.utilities.async_utils import call_sync_fn_in_threadpool
from fastmcp.utilities.types import find_kwarg_by_type, is_class_member_of_type

_logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from docket import Docket
from docket.worker import Worker
Expand Down Expand Up @@ -166,6 +170,9 @@ def get_task_session(session_id: str) -> ServerSession | None:
)
_current_docket: ContextVar[Docket | None] = ContextVar("docket", default=None)
_current_worker: ContextVar[Worker | None] = ContextVar("worker", default=None)
_task_access_token: ContextVar[AccessToken | None] = ContextVar(
"task_access_token", default=None
)


# --- Docket availability check ---
Expand Down Expand Up @@ -479,7 +486,8 @@ def get_access_token() -> AccessToken | None:
This function first tries to get the token from the current HTTP request's scope,
which is more reliable for long-lived connections where the SDK's auth_context_var
may become stale after token refresh. Falls back to the SDK's context var if no
request is available.
request is available. In background tasks (Docket workers), falls back to the
token snapshot stored in Redis at task submission time.

Returns:
The access token if an authenticated user is available, None otherwise.
Expand All @@ -502,6 +510,19 @@ def get_access_token() -> AccessToken | None:
if access_token is None:
access_token = _sdk_get_access_token()

# Fall back to background task snapshot (#3095)
# In Docket workers, neither HTTP request nor SDK context var are available.
# The token was snapshotted in Redis at submit_to_docket() time and restored
# into this ContextVar by _CurrentContext.__aenter__().
if access_token is None:
task_token = _task_access_token.get()
Comment on lines +517 to +518
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Restore snapshot for direct get_access_token task usage

get_access_token() now falls back to _task_access_token, but this ContextVar is only populated when _restore_task_access_token() runs via _CurrentContext or _CurrentAccessToken. A background task that calls get_access_token() directly and does not inject ctx: Context/CurrentAccessToken() never triggers restoration, so this branch still returns None even though submit_to_docket() persisted a token for the task.

Useful? React with 👍 / 👎.

if task_token is not None:
# Check expiration: if expires_at is set and past, treat as expired
if task_token.expires_at is not None:
if task_token.expires_at < int(datetime.now(timezone.utc).timestamp()):
return None
return task_token

if access_token is None or isinstance(access_token, AccessToken):
return access_token

Expand Down Expand Up @@ -719,14 +740,49 @@ async def resolve_dependencies(
# so that get_dependency_parameters can detect them.


async def _restore_task_access_token(
session_id: str, task_id: str
) -> Token[AccessToken | None] | None:
"""Restore the access token snapshot from Redis into a ContextVar.

Called when setting up context in a Docket worker. The token was stored at
submit_to_docket() time. The token is restored regardless of expiration;
get_access_token() checks expiry when reading from the ContextVar.

Returns:
The ContextVar token for resetting, or None if nothing was restored.
"""
docket = _current_docket.get()
if docket is None:
return None

token_key = docket.key(f"fastmcp:task:{session_id}:{task_id}:access_token")
try:
async with docket.redis() as redis:
token_data = await redis.get(token_key)
if token_data is not None:
restored = AccessToken.model_validate_json(token_data)
return _task_access_token.set(restored)
except Exception:
_logger.warning(
"Failed to restore access token for task %s:%s",
session_id,
task_id,
exc_info=True,
)
return None


class _CurrentContext(Dependency): # type: ignore[misc]
"""Async context manager for Context dependency.

In foreground (request) mode: returns the active context from _current_context.
In background (Docket worker) mode: creates a task-aware Context with task_id.
In background (Docket worker) mode: creates a task-aware Context with task_id
and restores the access token snapshot from Redis.
"""

_context: Context | None = None
_access_token_cv_token: Token[AccessToken | None] | None = None

async def __aenter__(self) -> Context:
from fastmcp.server.context import Context, _current_context
Expand All @@ -751,6 +807,12 @@ async def __aenter__(self) -> Context:
)
# Enter the context to set up ContextVars
await self._context.__aenter__()

# Restore access token snapshot from Redis (#3095)
self._access_token_cv_token = await _restore_task_access_token(
task_info.session_id, task_info.task_id
)

return self._context

# Neither foreground nor background context available
Expand All @@ -762,6 +824,10 @@ async def __aenter__(self) -> Context:
)

async def __aexit__(self, *args: object) -> None:
# Clean up access token ContextVar
if self._access_token_cv_token is not None:
_task_access_token.reset(self._access_token_cv_token)
self._access_token_cv_token = None
# Clean up if we created a context for background task
if self._context is not None:
await self._context.__aexit__(*args)
Expand Down Expand Up @@ -1130,8 +1196,22 @@ async def __aexit__(self, *args: object) -> None:
class _CurrentAccessToken(Dependency): # type: ignore[misc]
"""Async context manager for AccessToken dependency."""

_access_token_cv_token: Token[AccessToken | None] | None = None

async def __aenter__(self) -> AccessToken:
token = get_access_token()

# If no token found and we're in a Docket worker, try restoring from
# Redis. This handles the case where ctx: Context is not in the
# function signature, so _CurrentContext never ran the restoration.
if token is None:
task_info = get_task_context()
if task_info is not None:
self._access_token_cv_token = await _restore_task_access_token(
task_info.session_id, task_info.task_id
)
token = get_access_token()
Comment on lines +1210 to +1213
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Clear saved ContextVar token when access-token enter fails

_CurrentAccessToken.__aenter__ stores self._access_token_cv_token before confirming the restored token is usable; if the restored token is expired, get_access_token() returns None and __aenter__ raises. Because failed __aenter__ calls are not paired with __aexit__, this stale token remains on the reused dependency instance and a later invocation can call _task_access_token.reset() with a token from another context, raising ValueError during cleanup.

Useful? React with 👍 / 👎.


if token is None:
raise RuntimeError(
"No access token found. Ensure authentication is configured "
Expand All @@ -1140,7 +1220,9 @@ async def __aenter__(self) -> AccessToken:
return token

async def __aexit__(self, *args: object) -> None:
pass
if self._access_token_cv_token is not None:
_task_access_token.reset(self._access_token_cv_token)
self._access_token_cv_token = None


def CurrentAccessToken() -> AccessToken:
Expand Down
13 changes: 12 additions & 1 deletion src/fastmcp/server/tasks/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mcp.shared.exceptions import McpError
from mcp.types import INTERNAL_ERROR, ErrorData

from fastmcp.server.dependencies import _current_docket, get_context
from fastmcp.server.dependencies import _current_docket, get_access_token, get_context
from fastmcp.server.tasks.config import TaskMeta
from fastmcp.server.tasks.keys import build_task_key
from fastmcp.utilities.logging import get_logger
Expand Down Expand Up @@ -99,10 +99,21 @@ async def submit_to_docket(
f"fastmcp:task:{session_id}:{server_task_id}:poll_interval"
)
poll_interval_ms = int(component.task_config.poll_interval.total_seconds() * 1000)

# Snapshot the current access token (if any) for background task access (#3095)
access_token = get_access_token()
access_token_key = docket.key(
f"fastmcp:task:{session_id}:{server_task_id}:access_token"
)

async with docket.redis() as redis:
await redis.set(task_meta_key, task_key, ex=ttl_seconds)
await redis.set(created_at_key, created_at.isoformat(), ex=ttl_seconds)
await redis.set(poll_interval_key, str(poll_interval_ms), ex=ttl_seconds)
if access_token is not None:
await redis.set(
access_token_key, access_token.model_dump_json(), ex=ttl_seconds
)

# Register session for Context access in background workers (SEP-1686)
# This enables elicitation/sampling from background tasks via weakref
Expand Down
123 changes: 123 additions & 0 deletions tests/server/tasks/test_context_background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from fastmcp import FastMCP
from fastmcp.client import Client
from fastmcp.client.elicitation import ElicitResult
from fastmcp.server.auth import AccessToken
from fastmcp.server.context import Context
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.elicitation import AcceptedElicitation, DeclinedElicitation
from fastmcp.server.tasks.elicitation import handle_task_input

Expand Down Expand Up @@ -317,3 +319,124 @@ async def simple_tool() -> str:
fastmcp=mcp,
)
assert success is False


class TestAccessTokenInBackgroundTasks:
"""Tests for access token availability in background tasks (#3095).

Integration tests use Client(mcp) with the real memory:// Docket backend.
The token snapshot/restore round-trip flows through actual Redis (fakeredis).

Note: async tests run in isolated asyncio tasks, so ContextVar changes
are automatically scoped — no cleanup required.
"""

async def test_token_round_trips_through_background_task(self):
"""E2E: token set at submit time is available inside the worker."""
from mcp.server.auth.middleware.auth_context import auth_context_var
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser

mcp = FastMCP("token-roundtrip")

@mcp.tool(task=True)
async def check_token(ctx: Context) -> str:
token = get_access_token()
if token is None:
return "no-token"
return f"{token.token}|{token.client_id}"

test_token = AccessToken(
token="roundtrip-jwt",
client_id="test-client",
scopes=["read"],
claims={"sub": "user-1"},
)
auth_context_var.set(AuthenticatedUser(test_token))

async with Client(mcp) as client:
task = await client.call_tool("check_token", {}, task=True)
result = await task.result()
assert result.data == "roundtrip-jwt|test-client"

async def test_no_token_when_unauthenticated(self):
"""E2E: background task gets no token when nothing was set."""
mcp = FastMCP("no-auth")

@mcp.tool(task=True)
async def check_token(ctx: Context) -> str:
token = get_access_token()
return "no-token" if token is None else token.token

async with Client(mcp) as client:
task = await client.call_tool("check_token", {}, task=True)
result = await task.result()
assert result.data == "no-token"

async def test_expired_token_returns_none(self):
"""get_access_token() returns None when task token has expired."""
from datetime import datetime, timezone

from fastmcp.server.dependencies import _task_access_token

expired = AccessToken(
token="expired-jwt",
client_id="test-client",
scopes=["read"],
expires_at=int(datetime.now(timezone.utc).timestamp()) - 3600,
)
_task_access_token.set(expired)
assert get_access_token() is None

async def test_valid_token_with_future_expiry(self):
"""get_access_token() returns token when expiry is in the future."""
from datetime import datetime, timezone

from fastmcp.server.dependencies import _task_access_token

valid = AccessToken(
token="valid-jwt",
client_id="test-client",
scopes=["read"],
expires_at=int(datetime.now(timezone.utc).timestamp()) + 3600,
)
_task_access_token.set(valid)
result = get_access_token()
assert result is not None
assert result.token == "valid-jwt"

async def test_token_without_expiry_always_valid(self):
"""get_access_token() returns token when no expires_at is set."""
from fastmcp.server.dependencies import _task_access_token

no_expiry = AccessToken(
token="eternal-jwt",
client_id="test-client",
scopes=["read"],
)
_task_access_token.set(no_expiry)
result = get_access_token()
assert result is not None
assert result.token == "eternal-jwt"


class TestLifespanContextInBackgroundTasks:
"""Tests for lifespan_context availability in background tasks (#3095)."""

def test_lifespan_context_falls_back_to_server_result(self):
"""lifespan_context reads from server when request_context is None."""
mcp = FastMCP("test")
mcp._lifespan_result = {"db": "mock-db-connection", "cache": "mock-cache"}

ctx = Context(mcp, task_id="test-task")
assert ctx.request_context is None
assert ctx.lifespan_context == {
"db": "mock-db-connection",
"cache": "mock-cache",
}

def test_lifespan_context_returns_empty_dict_when_no_lifespan(self):
"""lifespan_context returns {} when no lifespan is configured."""
mcp = FastMCP("test")
ctx = Context(mcp, task_id="test-task")
assert ctx.request_context is None
assert ctx.lifespan_context == {}