-
Notifications
You must be signed in to change notification settings - Fork 2k
fix: snapshot access token for background tasks (#3095) #3138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,11 +9,13 @@ | |
|
|
||
| import contextlib | ||
| import inspect | ||
| import logging | ||
| import weakref | ||
| from collections.abc import AsyncGenerator, Callable | ||
| from contextlib import AsyncExitStack, asynccontextmanager | ||
| from contextvars import ContextVar | ||
| from contextvars import ContextVar, Token | ||
| from dataclasses import dataclass | ||
| from datetime import datetime, timezone | ||
| from functools import lru_cache | ||
| from typing import TYPE_CHECKING, Any, Protocol, cast, get_type_hints, runtime_checkable | ||
|
|
||
|
|
@@ -33,6 +35,8 @@ | |
| from fastmcp.utilities.async_utils import call_sync_fn_in_threadpool | ||
| from fastmcp.utilities.types import find_kwarg_by_type, is_class_member_of_type | ||
|
|
||
| _logger = logging.getLogger(__name__) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from docket import Docket | ||
| from docket.worker import Worker | ||
|
|
@@ -166,6 +170,9 @@ def get_task_session(session_id: str) -> ServerSession | None: | |
| ) | ||
| _current_docket: ContextVar[Docket | None] = ContextVar("docket", default=None) | ||
| _current_worker: ContextVar[Worker | None] = ContextVar("worker", default=None) | ||
| _task_access_token: ContextVar[AccessToken | None] = ContextVar( | ||
| "task_access_token", default=None | ||
| ) | ||
|
|
||
|
|
||
| # --- Docket availability check --- | ||
|
|
@@ -479,7 +486,8 @@ def get_access_token() -> AccessToken | None: | |
| This function first tries to get the token from the current HTTP request's scope, | ||
| which is more reliable for long-lived connections where the SDK's auth_context_var | ||
| may become stale after token refresh. Falls back to the SDK's context var if no | ||
| request is available. | ||
| request is available. In background tasks (Docket workers), falls back to the | ||
| token snapshot stored in Redis at task submission time. | ||
|
|
||
| Returns: | ||
| The access token if an authenticated user is available, None otherwise. | ||
|
|
@@ -502,6 +510,19 @@ def get_access_token() -> AccessToken | None: | |
| if access_token is None: | ||
| access_token = _sdk_get_access_token() | ||
|
|
||
| # Fall back to background task snapshot (#3095) | ||
| # In Docket workers, neither HTTP request nor SDK context var are available. | ||
| # The token was snapshotted in Redis at submit_to_docket() time and restored | ||
| # into this ContextVar by _CurrentContext.__aenter__(). | ||
| if access_token is None: | ||
| task_token = _task_access_token.get() | ||
| if task_token is not None: | ||
| # Check expiration: if expires_at is set and past, treat as expired | ||
| if task_token.expires_at is not None: | ||
| if task_token.expires_at < int(datetime.now(timezone.utc).timestamp()): | ||
| return None | ||
| return task_token | ||
|
|
||
| if access_token is None or isinstance(access_token, AccessToken): | ||
| return access_token | ||
|
|
||
|
|
@@ -719,14 +740,49 @@ async def resolve_dependencies( | |
| # so that get_dependency_parameters can detect them. | ||
|
|
||
|
|
||
| async def _restore_task_access_token( | ||
| session_id: str, task_id: str | ||
| ) -> Token[AccessToken | None] | None: | ||
| """Restore the access token snapshot from Redis into a ContextVar. | ||
|
|
||
| Called when setting up context in a Docket worker. The token was stored at | ||
| submit_to_docket() time. The token is restored regardless of expiration; | ||
| get_access_token() checks expiry when reading from the ContextVar. | ||
|
|
||
| Returns: | ||
| The ContextVar token for resetting, or None if nothing was restored. | ||
| """ | ||
| docket = _current_docket.get() | ||
| if docket is None: | ||
| return None | ||
|
|
||
| token_key = docket.key(f"fastmcp:task:{session_id}:{task_id}:access_token") | ||
| try: | ||
| async with docket.redis() as redis: | ||
| token_data = await redis.get(token_key) | ||
| if token_data is not None: | ||
| restored = AccessToken.model_validate_json(token_data) | ||
| return _task_access_token.set(restored) | ||
| except Exception: | ||
| _logger.warning( | ||
| "Failed to restore access token for task %s:%s", | ||
| session_id, | ||
| task_id, | ||
| exc_info=True, | ||
| ) | ||
| return None | ||
|
|
||
|
|
||
| class _CurrentContext(Dependency): # type: ignore[misc] | ||
| """Async context manager for Context dependency. | ||
|
|
||
| In foreground (request) mode: returns the active context from _current_context. | ||
| In background (Docket worker) mode: creates a task-aware Context with task_id. | ||
| In background (Docket worker) mode: creates a task-aware Context with task_id | ||
| and restores the access token snapshot from Redis. | ||
| """ | ||
|
|
||
| _context: Context | None = None | ||
| _access_token_cv_token: Token[AccessToken | None] | None = None | ||
|
|
||
| async def __aenter__(self) -> Context: | ||
| from fastmcp.server.context import Context, _current_context | ||
|
|
@@ -751,6 +807,12 @@ async def __aenter__(self) -> Context: | |
| ) | ||
| # Enter the context to set up ContextVars | ||
| await self._context.__aenter__() | ||
|
|
||
| # Restore access token snapshot from Redis (#3095) | ||
| self._access_token_cv_token = await _restore_task_access_token( | ||
| task_info.session_id, task_info.task_id | ||
| ) | ||
|
|
||
| return self._context | ||
|
|
||
| # Neither foreground nor background context available | ||
|
|
@@ -762,6 +824,10 @@ async def __aenter__(self) -> Context: | |
| ) | ||
|
|
||
| async def __aexit__(self, *args: object) -> None: | ||
| # Clean up access token ContextVar | ||
| if self._access_token_cv_token is not None: | ||
| _task_access_token.reset(self._access_token_cv_token) | ||
| self._access_token_cv_token = None | ||
| # Clean up if we created a context for background task | ||
| if self._context is not None: | ||
| await self._context.__aexit__(*args) | ||
|
|
@@ -1130,8 +1196,22 @@ async def __aexit__(self, *args: object) -> None: | |
| class _CurrentAccessToken(Dependency): # type: ignore[misc] | ||
| """Async context manager for AccessToken dependency.""" | ||
|
|
||
| _access_token_cv_token: Token[AccessToken | None] | None = None | ||
|
|
||
| async def __aenter__(self) -> AccessToken: | ||
| token = get_access_token() | ||
|
|
||
| # If no token found and we're in a Docket worker, try restoring from | ||
| # Redis. This handles the case where ctx: Context is not in the | ||
| # function signature, so _CurrentContext never ran the restoration. | ||
| if token is None: | ||
| task_info = get_task_context() | ||
| if task_info is not None: | ||
| self._access_token_cv_token = await _restore_task_access_token( | ||
| task_info.session_id, task_info.task_id | ||
| ) | ||
| token = get_access_token() | ||
|
Comment on lines
+1210
to
+1213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
|
|
||
| if token is None: | ||
| raise RuntimeError( | ||
| "No access token found. Ensure authentication is configured " | ||
|
|
@@ -1140,7 +1220,9 @@ async def __aenter__(self) -> AccessToken: | |
| return token | ||
|
|
||
| async def __aexit__(self, *args: object) -> None: | ||
| pass | ||
| if self._access_token_cv_token is not None: | ||
| _task_access_token.reset(self._access_token_cv_token) | ||
| self._access_token_cv_token = None | ||
|
|
||
|
|
||
| def CurrentAccessToken() -> AccessToken: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_access_token()now falls back to_task_access_token, but this ContextVar is only populated when_restore_task_access_token()runs via_CurrentContextor_CurrentAccessToken. A background task that callsget_access_token()directly and does not injectctx: Context/CurrentAccessToken()never triggers restoration, so this branch still returnsNoneeven thoughsubmit_to_docket()persisted a token for the task.Useful? React with 👍 / 👎.