diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index 31e759c981..a39817e585 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -338,9 +338,13 @@ async def report_progress( ) -> None: """Report progress for the current operation. + Works in both foreground (MCP progress notifications) and background + (Docket task execution) contexts. + Args: progress: Current progress value e.g. 24 total: Optional total value e.g. 100 + message: Optional status message describing current progress """ progress_token = ( @@ -349,16 +353,48 @@ async def report_progress( else None ) - if progress_token is None: + # Foreground: Send MCP progress notification if we have a token + if progress_token is not None: + await self.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + related_request_id=self.request_id, + ) return - await self.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - related_request_id=self.request_id, - ) + # Background: Update Docket execution progress (stored in Redis) + # This makes progress visible via tasks/get and notifications/tasks/status + from fastmcp.server.dependencies import is_docket_available + + if not is_docket_available(): + return + + try: + from docket.dependencies import Dependency + + # Get current execution from worker context + execution = Dependency.execution.get() + + # Update progress in Redis using Docket's progress API. + # Docket only exposes increment() (relative), so we compute + # the delta from the last reported value stored on this execution. + if total is not None: + await execution.progress.set_total(int(total)) + + current = int(progress) + last: int = getattr(execution, "_fastmcp_last_progress", 0) + delta = current - last + if delta > 0: + await execution.progress.increment(delta) + execution._fastmcp_last_progress = current # type: ignore[attr-defined] + + if message is not None: + await execution.progress.set_message(message) + except LookupError: + # Not running in Docket worker context - no progress tracking available + pass async def _paginate_list( self, diff --git a/src/fastmcp/server/tasks/__init__.py b/src/fastmcp/server/tasks/__init__.py index 13ba9e80ce..20dd733a86 100644 --- a/src/fastmcp/server/tasks/__init__.py +++ b/src/fastmcp/server/tasks/__init__.py @@ -11,6 +11,11 @@ get_client_task_id_from_key, parse_task_key, ) +from fastmcp.server.tasks.notifications import ( + ensure_subscriber_running, + push_notification, + stop_subscriber, +) __all__ = [ "TaskConfig", @@ -18,8 +23,11 @@ "TaskMode", "build_task_key", "elicit_for_task", + "ensure_subscriber_running", "get_client_task_id_from_key", "get_task_capabilities", "handle_task_input", "parse_task_key", + "push_notification", + "stop_subscriber", ] diff --git a/src/fastmcp/server/tasks/elicitation.py b/src/fastmcp/server/tasks/elicitation.py index 2fc0bef5fc..299382df63 100644 --- a/src/fastmcp/server/tasks/elicitation.py +++ b/src/fastmcp/server/tasks/elicitation.py @@ -5,7 +5,7 @@ an active request context, so elicitation requires special handling: 1. Set task status to "input_required" via Redis -2. Send notifications/tasks/updated with elicitation metadata +2. Send notifications/tasks/status with elicitation metadata 3. Wait for client to send input via tasks/sendInput 4. Resume task execution with the provided input @@ -15,11 +15,11 @@ from __future__ import annotations -import asyncio import json import logging import uuid -from typing import TYPE_CHECKING, Any +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, cast import mcp.types from mcp import ServerSession @@ -75,12 +75,21 @@ async def elicit_for_task( # Generate a unique request ID for this elicitation request_id = str(uuid.uuid4()) - # Get session ID for Redis key construction - session_id = getattr(session, "_fastmcp_state_prefix", None) - if session_id is None: - # Generate a session ID if not already set - session_id = str(uuid.uuid4()) - session._fastmcp_state_prefix = session_id # type: ignore[attr-defined] + # Get session ID from task context (authoritative source for background tasks) + # This is extracted from the Docket execution key: {session_id}:{task_id}:... + from fastmcp.server.dependencies import get_task_context + + task_context = get_task_context() + if task_context is not None: + session_id = task_context.session_id + else: + # Fallback: try to get from session attribute (shouldn't happen in background) + session_id = getattr(session, "_fastmcp_state_prefix", None) + if session_id is None: + raise RuntimeError( + "Cannot determine session_id for elicitation. " + "This typically means elicit_for_task() was called outside a Docket worker context." + ) # Store elicitation request in Redis request_key = ELICIT_REQUEST_KEY.format(session_id=session_id, task_id=task_id) @@ -107,13 +116,24 @@ async def elicit_for_task( ex=ELICIT_TTL_SECONDS, ) - # Send task status update notification with input_required status - # This follows SEP-1686 for background task status updates - notification = mcp.types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/tasks/updated", - params={}, - _meta={ # type: ignore[call-arg] + # Send task status update notification with input_required status. + # Use notifications/tasks/status so typed MCP clients can consume it. + # + # NOTE: We use the distributed notification queue instead of session.send_notification() + # This enables notifications to work when workers run in separate processes + # (Azure Web PubSub / Service Bus inspired pattern) + timestamp = datetime.now(timezone.utc).isoformat() + notification_dict = { + "method": "notifications/tasks/status", + "params": { + "taskId": task_id, + "status": "input_required", + "statusMessage": message, + "createdAt": timestamp, + "lastUpdatedAt": timestamp, + "ttl": ELICIT_TTL_SECONDS * 1000, + }, + "_meta": { "modelcontextprotocol.io/related-task": { "taskId": task_id, "status": "input_required", @@ -125,49 +145,87 @@ async def elicit_for_task( }, } }, - ) + } + + # Push notification to Redis queue (works from any process) + # Server's subscriber loop will forward to client + from fastmcp.server.tasks.notifications import push_notification - # Send notification (best effort - task status is stored in Redis) - # Log failures for debugging but don't fail the elicitation try: - await session.send_notification(notification) # type: ignore[arg-type] + await push_notification(session_id, notification_dict, docket) except Exception as e: + # Fail fast: if notification can't be queued, client won't know to respond + # Return cancel immediately rather than waiting for 1-hour timeout logger.warning( - "Failed to send input_required notification for task %s: %s", + "Failed to queue input_required notification for task %s, cancelling elicitation: %s", task_id, e, ) + # Best-effort cleanup + try: + async with docket.redis() as redis: + await redis.delete( + docket.key(request_key), + docket.key(status_key), + ) + except Exception: + pass # Keys will expire via TTL + return mcp.types.ElicitResult(action="cancel", content=None) - # Wait for response (poll Redis) - # In a production implementation, this could use Redis pub/sub for lower latency + # Wait for response using BLPOP (blocking pop) + # This is much more efficient than polling - single Redis round-trip + # that blocks until a response is pushed, vs 7,200 round-trips/hour with polling max_wait_seconds = ELICIT_TTL_SECONDS - poll_interval = 0.5 # seconds - for _ in range(int(max_wait_seconds / poll_interval)): + try: async with docket.redis() as redis: - response_data = await redis.get(docket.key(response_key)) - if response_data: + # BLPOP blocks until an item is pushed to the list or timeout + # Returns tuple of (key, value) or None on timeout + result = await cast( + Any, + redis.blpop( + [docket.key(response_key)], + timeout=max_wait_seconds, + ), + ) + + if result: + # result is (key, value) tuple + _key, response_data = result response = json.loads(response_data) + # Clean up Redis keys await redis.delete( docket.key(request_key), - docket.key(response_key), docket.key(status_key), ) + # Convert to ElicitResult return mcp.types.ElicitResult( action=response.get("action", "accept"), content=response.get("content"), ) + except Exception as e: + logger.warning( + "BLPOP failed for task %s elicitation, falling back to cancel: %s", + task_id, + e, + ) - await asyncio.sleep(poll_interval) - - # Timeout - treat as cancellation - async with docket.redis() as redis: - await redis.delete( - docket.key(request_key), - docket.key(response_key), - docket.key(status_key), + # Timeout or error - treat as cancellation + # Best-effort cleanup - if Redis is unavailable, keys will expire via TTL + try: + async with docket.redis() as redis: + await redis.delete( + docket.key(request_key), + docket.key(response_key), + docket.key(status_key), + ) + except Exception as cleanup_error: + logger.debug( + "Failed to clean up elicitation keys for task %s (will expire via TTL): %s", + task_id, + cleanup_error, ) return mcp.types.ElicitResult(action="cancel", content=None) @@ -213,12 +271,15 @@ async def handle_task_input( if status is None or status.decode("utf-8") != "waiting": return False - # Store the response - await redis.set( + # Push response to list - this wakes up the BLPOP in elicit_for_task + # Using LPUSH instead of SET enables the efficient blocking wait pattern + await redis.lpush( # type: ignore[invalid-await] # redis-py union type (sync/async) docket.key(response_key), json.dumps(response), - ex=ELICIT_TTL_SECONDS, ) + # Set TTL on the response list (in case BLPOP doesn't consume it) + await redis.expire(docket.key(response_key), ELICIT_TTL_SECONDS) + # Update status to "responded" await redis.set( docket.key(status_key), diff --git a/src/fastmcp/server/tasks/handlers.py b/src/fastmcp/server/tasks/handlers.py index 02da22148d..494fce87fe 100644 --- a/src/fastmcp/server/tasks/handlers.py +++ b/src/fastmcp/server/tasks/handlers.py @@ -17,6 +17,7 @@ from fastmcp.server.dependencies import _current_docket, get_context from fastmcp.server.tasks.config import TaskMeta from fastmcp.server.tasks.keys import build_task_key +from fastmcp.utilities.logging import get_logger if TYPE_CHECKING: from fastmcp.prompts.prompt import Prompt @@ -24,6 +25,8 @@ from fastmcp.resources.template import ResourceTemplate from fastmcp.tools.tool import Tool +logger = get_logger(__name__) + # Redis mapping TTL buffer: Add 15 minutes to Docket's execution_ttl TASK_MAPPING_TTL_BUFFER_SECONDS = 15 * 60 @@ -109,21 +112,31 @@ async def submit_to_docket( register_task_session(session_id, ctx.session) - # Send notifications/tasks/created per SEP-1686 (mandatory) - # Send BEFORE queuing to avoid race where task completes before notification - notification = mcp.types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/tasks/created", - params={}, # Empty params per spec - _meta={ # type: ignore[call-arg] # _meta is Pydantic alias for meta field - "modelcontextprotocol.io/related-task": { + # Send an initial tasks/status notification before queueing. + # This guarantees clients can observe task creation immediately. + notification = mcp.types.TaskStatusNotification.model_validate( + { + "method": "notifications/tasks/status", + "params": { "taskId": server_task_id, - } - }, + "status": "working", + "statusMessage": "Task submitted", + "createdAt": created_at, + "lastUpdatedAt": created_at, + "ttl": ttl_ms, + "pollInterval": poll_interval_ms, + }, + "_meta": { + "modelcontextprotocol.io/related-task": { + "taskId": server_task_id, + } + }, + } ) + server_notification = mcp.types.ServerNotification(notification) with suppress(Exception): # Don't let notification failures break task creation - await ctx.session.send_notification(notification) # type: ignore[arg-type] + await ctx.session.send_notification(server_notification) # Queue function to Docket by key (result storage via execution_ttl) # Use component.add_to_docket() which handles calling conventions @@ -151,6 +164,34 @@ async def submit_to_docket( poll_interval_ms, ) + # Start notification subscriber for distributed elicitation (idempotent) + # This enables ctx.elicit() to work when workers run in separate processes + # Subscriber forwards notifications from Redis queue to client session + from fastmcp.server.tasks.notifications import ( + ensure_subscriber_running, + stop_subscriber, + ) + + try: + await ensure_subscriber_running(session_id, ctx.session, docket) + + # Register cleanup callback on session exit (once per session) + # This ensures subscriber is stopped when the session disconnects + if ( + hasattr(ctx.session, "_exit_stack") + and ctx.session._exit_stack is not None + and not getattr(ctx.session, "_notification_cleanup_registered", False) + ): + + async def _cleanup_subscriber() -> None: + await stop_subscriber(session_id) + + ctx.session._exit_stack.push_async_callback(_cleanup_subscriber) + ctx.session._notification_cleanup_registered = True # type: ignore[attr-defined] + except Exception as e: + # Non-fatal: elicitation will still work via polling fallback + logger.debug("Failed to start notification subscriber: %s", e) + # Return CreateTaskResult with proper Task object # Tasks MUST begin in "working" status per SEP-1686 final spec (line 381) return mcp.types.CreateTaskResult( diff --git a/src/fastmcp/server/tasks/notifications.py b/src/fastmcp/server/tasks/notifications.py new file mode 100644 index 0000000000..2c65e31e1b --- /dev/null +++ b/src/fastmcp/server/tasks/notifications.py @@ -0,0 +1,256 @@ +"""Distributed notification queue for background task events (SEP-1686). + +Enables distributed Docket workers to send MCP notifications to clients +without holding session references. Workers push to a Redis queue, +the MCP server process subscribes and forwards to the client's session. + +Pattern: Fire-and-forward with retry +- One queue per session_id +- LPUSH/BRPOP for reliable ordered delivery +- Retry up to 3 times on delivery failure, then discard +- TTL-based expiration for stale messages + +Note: Docket's execution.subscribe() handles task state/progress events via +Redis Pub/Sub. This module handles elicitation-specific notifications that +require reliable delivery (input_required prompts, cancel signals). +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import weakref +from contextlib import suppress +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, cast + +import mcp.types + +if TYPE_CHECKING: + from docket import Docket + from mcp.server.session import ServerSession + +logger = logging.getLogger(__name__) + +# Redis key patterns +NOTIFICATION_QUEUE_KEY = "fastmcp:notifications:{session_id}" +NOTIFICATION_ACTIVE_KEY = "fastmcp:notifications:{session_id}:active" + +# Configuration +NOTIFICATION_TTL_SECONDS = 300 # 5 minute message TTL (elicitation response window) +MAX_DELIVERY_ATTEMPTS = 3 # Retry failed deliveries before discarding +SUBSCRIBER_TIMEOUT_SECONDS = 30 # BRPOP timeout (also heartbeat interval) + + +async def push_notification( + session_id: str, + notification: dict[str, Any], + docket: Docket, +) -> None: + """Push notification to session's queue (called from Docket worker). + + Used for elicitation-specific notifications (input_required, cancel) + that need reliable delivery across distributed processes. + + Args: + session_id: Target session's identifier + notification: MCP notification dict (method, params, _meta) + docket: Docket instance for Redis access + """ + key = docket.key(NOTIFICATION_QUEUE_KEY.format(session_id=session_id)) + message = json.dumps( + { + "notification": notification, + "attempt": 0, + "enqueued_at": datetime.now(timezone.utc).isoformat(), + } + ) + async with docket.redis() as redis: + await redis.lpush(key, message) # type: ignore[invalid-await] # redis-py union type (sync/async) + await redis.expire(key, NOTIFICATION_TTL_SECONDS) + + +async def notification_subscriber_loop( + session_id: str, + session: ServerSession, + docket: Docket, +) -> None: + """Subscribe to notification queue and forward to session. + + Runs in the MCP server process. Bridges distributed workers to clients. + + This loop: + 1. Maintains a heartbeat (active subscriber marker for debugging) + 2. Blocks on BRPOP waiting for notifications + 3. Forwards notifications to the client's session + 4. Retries failed deliveries, then discards (no dead-letter queue) + + Args: + session_id: Session identifier to subscribe to + session: MCP ServerSession for sending notifications + docket: Docket instance for Redis access + """ + queue_key = docket.key(NOTIFICATION_QUEUE_KEY.format(session_id=session_id)) + active_key = docket.key(NOTIFICATION_ACTIVE_KEY.format(session_id=session_id)) + + logger.debug("Starting notification subscriber for session %s", session_id) + + while True: + try: + async with docket.redis() as redis: + # Heartbeat: mark subscriber as active (for distributed debugging) + await redis.set(active_key, "1", ex=SUBSCRIBER_TIMEOUT_SECONDS * 2) + + # Blocking wait for notification (timeout refreshes heartbeat) + # Using BRPOP (right pop) for FIFO order with LPUSH (left push) + result = await cast( + Any, redis.brpop([queue_key], timeout=SUBSCRIBER_TIMEOUT_SECONDS) + ) + if not result: + continue # Timeout - refresh heartbeat and retry + + _, message_bytes = result + message = json.loads(message_bytes) + notification_dict = message["notification"] + attempt = message.get("attempt", 0) + + try: + # Reconstruct and send MCP notification + await _send_mcp_notification(session, notification_dict) + logger.debug( + "Delivered notification to session %s (attempt %d)", + session_id, + attempt + 1, + ) + except Exception as send_error: + # Delivery failed - retry or discard + if attempt < MAX_DELIVERY_ATTEMPTS - 1: + # Re-queue with incremented attempt (back of queue) + message["attempt"] = attempt + 1 + message["last_error"] = str(send_error) + await redis.lpush(queue_key, json.dumps(message)) # type: ignore[invalid-await] + logger.debug( + "Requeued notification for session %s (attempt %d): %s", + session_id, + attempt + 2, + send_error, + ) + else: + # Discard after max attempts (session likely disconnected) + logger.warning( + "Discarding notification for session %s after %d attempts: %s", + session_id, + MAX_DELIVERY_ATTEMPTS, + send_error, + ) + + except asyncio.CancelledError: + # Graceful shutdown - leave pending messages in queue for reconnect + logger.debug("Notification subscriber cancelled for session %s", session_id) + break + except Exception as e: + logger.debug( + "Notification subscriber error for session %s: %s", session_id, e + ) + await asyncio.sleep(1) # Backoff on error + + +async def _send_mcp_notification( + session: ServerSession, + notification_dict: dict[str, Any], +) -> None: + """Reconstruct MCP notification from dict and send to session. + + Args: + session: MCP ServerSession + notification_dict: Notification as dict (method, params, _meta) + """ + method = notification_dict.get("method", "notifications/tasks/status") + if method != "notifications/tasks/status": + raise ValueError(f"Unsupported notification method for subscriber: {method}") + + notification = mcp.types.TaskStatusNotification.model_validate( + { + "method": "notifications/tasks/status", + "params": notification_dict.get("params", {}), + "_meta": notification_dict.get("_meta"), + } + ) + server_notification = mcp.types.ServerNotification(notification) + + await session.send_notification(server_notification) + + +# ============================================================================= +# Subscriber Management +# ============================================================================= + +# Registry of active subscribers per session (prevents duplicates) +# Uses weakref to session to detect disconnects +_active_subscribers: dict[ + str, tuple[asyncio.Task[None], weakref.ref[ServerSession]] +] = {} + + +async def ensure_subscriber_running( + session_id: str, + session: ServerSession, + docket: Docket, +) -> None: + """Start notification subscriber if not already running (idempotent). + + Subscriber is created on first task submission and cleaned up on disconnect. + Safe to call multiple times for the same session. + + Args: + session_id: Session identifier + session: MCP ServerSession + docket: Docket instance + """ + # Check if subscriber already running for this session + if session_id in _active_subscribers: + task, session_ref = _active_subscribers[session_id] + # Check if task is still running AND session is still alive + if not task.done() and session_ref() is not None: + return # Already running + + # Task finished or session dead - clean up + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + del _active_subscribers[session_id] + + # Start new subscriber task + task = asyncio.create_task( + notification_subscriber_loop(session_id, session, docket), + name=f"notification-subscriber-{session_id[:8]}", + ) + _active_subscribers[session_id] = (task, weakref.ref(session)) + logger.debug("Started notification subscriber for session %s", session_id) + + +async def stop_subscriber(session_id: str) -> None: + """Stop notification subscriber for a session. + + Called when session disconnects. Pending messages remain in queue + for delivery if client reconnects (with TTL expiration). + + Args: + session_id: Session identifier + """ + if session_id not in _active_subscribers: + return + + task, _ = _active_subscribers.pop(session_id) + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + logger.debug("Stopped notification subscriber for session %s", session_id) + + +def get_subscriber_count() -> int: + """Get number of active subscribers (for monitoring).""" + return len(_active_subscribers) diff --git a/tests/server/tasks/test_context_background_task.py b/tests/server/tasks/test_context_background_task.py index b778a63b29..219a7ae216 100644 --- a/tests/server/tasks/test_context_background_task.py +++ b/tests/server/tasks/test_context_background_task.py @@ -1,11 +1,25 @@ -"""Tests for Context background task support (SEP-1686).""" +"""Tests for Context background task support (SEP-1686). + +Tests Context API surface (unit) and background task elicitation (integration). +Integration tests use Client(mcp) with the real memory:// Docket backend — +no mocking of Redis, Docket, or session internals. +""" + +import asyncio +from typing import cast import pytest +from mcp import ServerSession from fastmcp import FastMCP +from fastmcp.client import Client from fastmcp.server.context import Context -from fastmcp.server.elicitation import AcceptedElicitation -from fastmcp.server.tasks.elicitation import elicit_for_task, handle_task_input +from fastmcp.server.elicitation import AcceptedElicitation, DeclinedElicitation +from fastmcp.server.tasks.elicitation import handle_task_input + +# ============================================================================= +# Unit tests: Context API surface (no Redis/Docket needed) +# ============================================================================= class TestContextBackgroundTaskSupport: @@ -30,7 +44,7 @@ def test_context_task_id_is_readonly(self): mcp = FastMCP("test") ctx = Context(mcp, task_id="test-task-123") with pytest.raises(AttributeError): - ctx.task_id = "new-id" # type: ignore[misc] + setattr(ctx, "task_id", "new-id") class TestContextSessionProperty: @@ -52,9 +66,10 @@ class MockSession: _fastmcp_state_prefix = "test-session" mock_session = MockSession() - ctx = Context(mcp, session=mock_session, task_id="test-task-123") # type: ignore[arg-type] + ctx = Context( + mcp, session=cast(ServerSession, mock_session), task_id="test-task-123" + ) - # In background task mode, should return the stored session assert ctx.session is mock_session def test_session_uses_stored_session_during_on_initialize(self): @@ -65,32 +80,75 @@ class MockSession: _fastmcp_state_prefix = "test-session" mock_session = MockSession() - # Simulating on_initialize: has session but not a background task - ctx = Context(mcp, session=mock_session) # type: ignore[arg-type] + ctx = Context(mcp, session=cast(ServerSession, mock_session)) - # Should return the stored session as fallback assert ctx.session is mock_session class TestContextElicitBackgroundTask: """Tests for Context.elicit() in background task mode.""" - @pytest.mark.asyncio async def test_elicit_raises_when_background_task_but_no_docket(self): """elicit() should raise when in background task mode but Docket unavailable.""" mcp = FastMCP("test") ctx = Context(mcp, task_id="test-task-123") - # Set up minimal session mock class MockSession: _fastmcp_state_prefix = "test-session" - ctx._session = MockSession() # type: ignore[assignment] + ctx._session = cast(ServerSession, MockSession()) with pytest.raises(RuntimeError, match="Docket"): await ctx.elicit("Need input", str) +class TestElicitFailFast: + """Tests for elicit_for_task fail-fast on notification push failure.""" + + async def test_elicit_returns_cancel_when_notification_push_fails(self): + """elicit_for_task should return cancel immediately when push_notification fails. + + If the client can't receive the input_required notification, waiting + for a response that will never come would block for up to 1 hour. + Instead, we return cancel immediately (fail-fast). + + This test patches ONLY push_notification — all other components + (Docket, Redis, session) are real via the memory:// backend. + """ + from unittest.mock import patch + + from fastmcp.server.elicitation import CancelledElicitation + + mcp = FastMCP("failfast-test") + elicit_started = asyncio.Event() + captured: dict[str, object] = {} + + @mcp.tool(task=True) + async def failfast_tool(ctx: Context) -> str: + elicit_started.set() + result = await ctx.elicit("This notification will fail", str) + captured["result_type"] = type(result).__name__ + captured["is_cancelled"] = isinstance(result, CancelledElicitation) + return "done" + + # Patch push_notification BEFORE starting client so it's active + # when the tool runs in the Docket worker + with patch( + "fastmcp.server.tasks.notifications.push_notification", + side_effect=ConnectionError("Redis queue unavailable"), + ): + async with Client(mcp) as client: + task = await client.call_tool("failfast_tool", {}, task=True) + await asyncio.wait_for(elicit_started.wait(), timeout=5.0) + await task.wait(timeout=10.0) + result = await task.result() + assert result.data == "done" + + # The tool should have received CancelledElicitation (fail-fast) + assert captured["is_cancelled"] is True + assert captured["result_type"] == "CancelledElicitation" + + class TestContextDocumentation: """Tests to verify Context documentation and API surface.""" @@ -110,753 +168,227 @@ def test_session_has_docstring(self): assert "background task" in Context.session.fget.__doc__.lower() -class TestBackgroundTaskElicitationE2E: - """End-to-end tests for background task elicitation (SEP-1686). - - These tests demonstrate the full flow: - 1. Client calls a tool with task=True (background execution) - 2. Tool uses ctx.elicit() to request user input - 3. Task status changes to "input_required" - 4. Client sends input via handle_task_input() - 5. Task resumes and completes with the elicited value +# ============================================================================= +# Integration tests: Client(mcp) + memory:// Docket backend +# ============================================================================= - This simulates what a client would see when interacting with - a background task that needs user input. - """ - - async def test_elicit_for_task_stores_request_in_redis(self): - """Test that elicit_for_task stores the elicitation request in Redis. - This tests the Redis coordination layer that enables client interaction. - When a background task calls elicit(), the request is stored in Redis - so clients can retrieve it and respond. - """ - from unittest.mock import AsyncMock, MagicMock, patch - - from fastmcp.server.tasks.elicitation import ( - elicit_for_task, - ) - - # Create mocks - mock_redis = AsyncMock() - mock_redis.set = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) # No response yet - mock_redis.delete = AsyncMock() - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session-id" - mock_session.send_notification = AsyncMock() - - # Call elicit_for_task with a short timeout to avoid blocking - with patch("fastmcp.server.tasks.elicitation.ELICIT_TTL_SECONDS", 1): - with patch("fastmcp.server.tasks.elicitation.asyncio.sleep", AsyncMock()): - # Make it return after first poll - mock_redis.get = AsyncMock( - return_value=b'{"action": "accept", "content": {"value": 42}}' - ) - - result = await elicit_for_task( - task_id="test-task-123", - session=mock_session, - message="Please provide a number", - schema={ - "type": "object", - "properties": {"value": {"type": "integer"}}, - }, - fastmcp=mock_fastmcp, - ) - - # Verify the result - assert result.action == "accept" - assert result.content == {"value": 42} - - # Verify Redis operations were called - assert mock_redis.set.call_count >= 2 # request + status - - async def test_handle_task_input_stores_response(self): - """Test that handle_task_input stores the response in Redis. - - This tests the client-side flow: when a client sends input via - tasks/sendInput, the response is stored in Redis for the waiting task. - """ - from unittest.mock import AsyncMock, MagicMock - - # Create mocks - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=b"waiting") # Status is waiting - mock_redis.set = AsyncMock() - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - # Call handle_task_input - success = await handle_task_input( - task_id="test-task-123", - session_id="test-session-id", - action="accept", - content={"value": 42}, - fastmcp=mock_fastmcp, - ) - - # Verify success - assert success is True - - # Verify Redis operations - assert mock_redis.set.call_count == 2 # response + status update - - async def test_handle_task_input_rejects_when_not_waiting(self): - """Test that handle_task_input rejects input when task isn't waiting. - - This verifies proper state management - clients can only send input - when a task is actually waiting for it. - """ - from unittest.mock import AsyncMock, MagicMock - - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) # No waiting status - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - success = await handle_task_input( - task_id="test-task-123", - session_id="test-session-id", - action="accept", - content={"value": 42}, - fastmcp=mock_fastmcp, - ) - - # Should fail because no task is waiting - assert success is False - - async def test_elicit_for_task_sends_notification(self): - """Test that elicit_for_task sends input_required notification. - - Per SEP-1686, the server should send notifications/tasks/updated - with status="input_required" when a task needs input. - """ - from unittest.mock import AsyncMock, MagicMock, patch +class TestBackgroundTaskIntegration: + """Integration tests for background task context using real Docket memory backend. - mock_redis = AsyncMock() - mock_redis.set = AsyncMock() - mock_redis.get = AsyncMock( - return_value=b'{"action": "accept", "content": {"value": 1}}' - ) - mock_redis.delete = AsyncMock() - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session" - mock_session.send_notification = AsyncMock() - - with patch("fastmcp.server.tasks.elicitation.asyncio.sleep", AsyncMock()): - await elicit_for_task( - task_id="my-task-id", - session=mock_session, - message="Enter value", - schema={"type": "object"}, - fastmcp=mock_fastmcp, - ) - - # Verify notification was sent - mock_session.send_notification.assert_called_once() - notification = mock_session.send_notification.call_args[0][0] - assert notification.method == "notifications/tasks/updated" - - async def test_elicit_for_task_timeout_returns_cancel(self): - """Test that elicit_for_task returns cancel on timeout. - - If no response is received within the TTL, the elicitation - should be treated as cancelled. - """ - from unittest.mock import AsyncMock, MagicMock, patch - - mock_redis = AsyncMock() - mock_redis.set = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) # Never responds - mock_redis.delete = AsyncMock() - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session" - mock_session.send_notification = AsyncMock() - - # Use very short TTL for test - with patch("fastmcp.server.tasks.elicitation.ELICIT_TTL_SECONDS", 0.1): - with patch( - "fastmcp.server.tasks.elicitation.asyncio.sleep", - AsyncMock(), - ): - result = await elicit_for_task( - task_id="timeout-task", - session=mock_session, - message="This will timeout", - schema={"type": "object"}, - fastmcp=mock_fastmcp, - ) - - # Should return cancel on timeout - assert result.action == "cancel" - assert result.content is None - - async def test_elicit_notification_includes_full_schema(self): - """Test that the notification includes the full JSON schema for complex types. - - This test demonstrates what the client sees when eliciting a Pydantic model. - The client receives a full JSON Schema that describes the expected input, - which they can use to: - - Render a dynamic form - - Validate user input before sending - - Show field descriptions to the user - - Example notification metadata for a UserInfo model: - ```json - { - "modelcontextprotocol.io/related-task": { - "taskId": "test-task", - "status": "input_required", - "statusMessage": "Please provide user info", - "elicitation": { - "requestId": "...", - "message": "Please provide user info", - "requestedSchema": { - "type": "object", - "properties": { - "name": {"type": "string", "title": "Name"}, - "age": {"type": "integer", "title": "Age"} - }, - "required": ["name", "age"], - "title": "UserInfo" - } - } - } - } - ``` - """ - from unittest.mock import AsyncMock, MagicMock, patch - - from pydantic import BaseModel - - class UserInfo(BaseModel): - """User information for registration.""" - - name: str - age: int - - mock_redis = AsyncMock() - mock_redis.set = AsyncMock() - mock_redis.get = AsyncMock( - return_value=b'{"action": "accept", "content": {"name": "Alice", "age": 30}}' - ) - mock_redis.delete = AsyncMock() - - mock_docket = MagicMock() - mock_docket.redis = MagicMock(return_value=AsyncMock()) - mock_docket.redis.return_value.__aenter__ = AsyncMock(return_value=mock_redis) - mock_docket.redis.return_value.__aexit__ = AsyncMock() - mock_docket.key = lambda k: k - - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket - - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session" - mock_session.send_notification = AsyncMock() - - # Create task-aware context - ctx = Context( - mock_fastmcp, - session=mock_session, - task_id="schema-test-task", - ) - - # Call elicit with a Pydantic model type - with patch("fastmcp.server.tasks.elicitation.asyncio.sleep", AsyncMock()): - result = await ctx.elicit("Please provide user info", UserInfo) - - # Verify the notification includes the full schema - mock_session.send_notification.assert_called_once() - notification = mock_session.send_notification.call_args[0][0] - meta = notification._meta - related_task = meta["modelcontextprotocol.io/related-task"] - schema = related_task["elicitation"]["requestedSchema"] - - # Verify schema structure matches UserInfo - assert schema["type"] == "object" - assert "properties" in schema - assert "name" in schema["properties"] - assert "age" in schema["properties"] - assert schema["properties"]["name"]["type"] == "string" - assert schema["properties"]["age"]["type"] == "integer" - assert "required" in schema - assert set(schema["required"]) == {"name", "age"} - - # Verify the result is properly parsed into the Pydantic model - assert result.action == "accept" - assert isinstance(result, AcceptedElicitation) # Type narrowing - assert isinstance(result.data, UserInfo) - assert result.data.name == "Alice" - assert result.data.age == 30 - - -class TestBackgroundTaskContextWiring: - """Integration tests for Context wiring in Docket workers. - - These tests verify that when a background task runs in a Docket worker, - the Context dependency is properly created with task_id and session, - allowing ctx.elicit() to work transparently. - - Per Chris Guidry's review request: "Could we get at least one test showing - the end-to-end of it working, with a background task that's eliciting input? - This will help with what the client-side sees when this happens." - - The key test is `test_context_elicit_full_flow_with_mocked_redis` which shows: - - CLIENT RECEIVES: - notifications/tasks/updated with: - - taskId: the background task ID - - status: "input_required" - - statusMessage: the elicit prompt - - elicitation.requestedSchema: JSON schema for expected input - - CLIENT RESPONDS: - handle_task_input(task_id, session_id, action="accept", content={...}) - - TOOL RECEIVES: - AcceptedElicitation(action="accept", data=) + These tests use Client(mcp) with the memory:// broker — no mocking. + The memory:// backend provides a fully functional in-memory Redis store + that Docket uses automatically when running tests. """ - async def test_context_is_created_with_task_id_in_worker(self): - """Test that Context is created with task_id when running in Docket worker. - - This verifies the wiring from _CurrentContext that creates a task-aware - Context when get_task_context() returns TaskContextInfo. - """ - from unittest.mock import MagicMock, patch - - from fastmcp.server.dependencies import ( - TaskContextInfo, - _current_server, - _CurrentContext, - _task_sessions, - ) - - # Set up mock server - mock_server = MagicMock() - mock_server._docket = MagicMock() - server_token = _current_server.set(MagicMock(return_value=mock_server)) - - # Set up mock session in registry - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session-id" - _task_sessions["test-session-id"] = MagicMock(return_value=mock_session) - - try: - # Mock get_task_context to return TaskContextInfo - task_info = TaskContextInfo( - task_id="test-task-123", - session_id="test-session-id", - ) - with patch( - "fastmcp.server.dependencies.get_task_context", - return_value=task_info, - ): - # Create the dependency and enter it - dep = _CurrentContext() - ctx = await dep.__aenter__() - - # Verify context is task-aware - assert ctx.is_background_task is True - assert ctx.task_id == "test-task-123" - assert ctx.session is mock_session - - # Clean up - await dep.__aexit__(None, None, None) - finally: - _current_server.reset(server_token) - _task_sessions.pop("test-session-id", None) - - async def test_context_falls_back_to_foreground_mode(self): - """Test that Context uses foreground mode when not in worker context. - - When _current_context has a value (normal request handling), - _CurrentContext should return that context instead of creating a new one. - """ - from unittest.mock import MagicMock - - from fastmcp.server.context import Context, _current_context - from fastmcp.server.dependencies import _CurrentContext - - mcp = MagicMock() - foreground_ctx = Context(mcp) - - # Set the foreground context - token = _current_context.set(foreground_ctx) - try: - dep = _CurrentContext() - ctx = await dep.__aenter__() - - # Should return the foreground context - assert ctx is foreground_ctx - assert ctx.is_background_task is False - - await dep.__aexit__(None, None, None) - finally: - _current_context.reset(token) - - async def test_session_registered_when_task_submitted(self): - """Test that session is registered when a task is submitted to Docket. - - This verifies that submit_to_docket calls register_task_session, - which enables the Context wiring in background workers. - """ - import asyncio - - from fastmcp import FastMCP - from fastmcp.client import Client - from fastmcp.server.dependencies import get_task_session - - mcp = FastMCP("test-server") - - task_started = asyncio.Event() - session_id_captured = None + async def test_report_progress_in_background_task(self): + """report_progress() should complete without error in a background task.""" + mcp = FastMCP("progress-test") + progress_reported = asyncio.Event() @mcp.tool(task=True) - async def capture_session_tool(ctx: Context) -> str: - """Tool that captures the session ID for verification.""" - nonlocal session_id_captured - task_started.set() - # Access session to verify it works - session_id_captured = ctx.session_id + async def progress_tool(ctx: Context) -> str: + await ctx.report_progress(0, 100, "Starting...") + await ctx.report_progress(50, 100, "Half done") + await ctx.report_progress(100, 100, "Complete") + progress_reported.set() return "done" async with Client(mcp) as client: - # Start the task - task = await client.call_tool("capture_session_tool", {}, task=True) - assert task is not None - - # Wait for the task to start - await asyncio.wait_for(task_started.wait(), timeout=5.0) - - # Verify the session was registered - assert session_id_captured is not None - # The session should be retrievable via get_task_session - # (it was registered when the task was submitted) - # Session may be available or None if cleaned up - key is registration happened - _ = get_task_session(session_id_captured) - - # Wait for task to complete + task = await client.call_tool("progress_tool", {}, task=True) + await asyncio.wait_for(progress_reported.wait(), timeout=5.0) await task.wait(timeout=5.0) result = await task.result() assert result.data == "done" - async def test_context_elicit_works_in_background_task(self): - """E2E test: verify Context is properly wired in background tasks. - - This test demonstrates that: - 1. Context.task_id is set correctly in background tasks - 2. Context.is_background_task returns True - 3. Context.session_id is available - - The wiring is what enables ctx.elicit() to work in background tasks. - """ - import asyncio - - from fastmcp import FastMCP - from fastmcp.client import Client - from fastmcp.server.context import Context - - mcp = FastMCP("context-wiring-test") - - # Track what happens in the background task + async def test_context_wiring_in_background_task(self): + """Context should be properly wired with task_id and session_id.""" + mcp = FastMCP("wiring-test") task_completed = asyncio.Event() - captured_task_id: str | None = None - captured_session_id: str | None = None - captured_is_background: bool | None = None + captured: dict[str, object] = {} @mcp.tool(task=True) - async def verify_context_tool(ctx: Context) -> str: - """Tool that verifies Context is wired correctly for background tasks.""" - nonlocal captured_task_id, captured_session_id, captured_is_background - - # Capture context properties - this is the key verification - captured_task_id = ctx.task_id - captured_session_id = ctx.session_id - captured_is_background = ctx.is_background_task - + async def verify_wiring(ctx: Context) -> str: + captured["task_id"] = ctx.task_id + captured["session_id"] = ctx.session_id + captured["is_background"] = ctx.is_background_task task_completed.set() - return f"task_id={ctx.task_id}, is_background={ctx.is_background_task}" + return "ok" async with Client(mcp) as client: - # Start the background task - task = await client.call_tool("verify_context_tool", {}, task=True) - assert task is not None - assert task.task_id is not None - - # Wait for the task to complete - await asyncio.wait_for(task_completed.wait(), timeout=10.0) - - # Verify Context was properly wired in the background task - assert captured_task_id is not None, "Context.task_id should be set" - assert captured_session_id is not None, "Context.session_id should be set" - assert captured_is_background is True, ( - "Context.is_background_task should be True" - ) - - # Wait for task result - await task.wait(timeout=10.0) + task = await client.call_tool("verify_wiring", {}, task=True) + await asyncio.wait_for(task_completed.wait(), timeout=5.0) + await task.wait(timeout=5.0) result = await task.result() - assert "is_background=True" in result.data + assert result.data == "ok" - async def test_context_elicit_full_flow_with_mocked_redis(self): - """E2E test with mocked Redis to show complete elicitation flow. + assert captured["task_id"] is not None + assert captured["session_id"] is not None + assert captured["is_background"] is True - This test demonstrates what the client sees during background task - elicitation, with a mocked Redis layer to avoid requiring real Redis. + async def test_elicit_accept_flow(self): + """E2E: tool elicits input, client accepts, tool receives value. Flow: - 1. Tool calls ctx.elicit() in background task - 2. Elicitation stores request in Redis, sends input_required notification - 3. Simulated client sends response via handle_task_input() - 4. Tool receives response and completes - - This is the key test that fulfills Chris Guidry's request for an - "end-to-end test showing a background task that's eliciting input" - and demonstrates "what the client-side sees when this happens." + 1. Tool calls ctx.elicit("name?", str) — blocks waiting for input + 2. Client polls handle_task_input(action="accept", content={"value":"Bob"}) + 3. Tool resumes with AcceptedElicitation(data="Bob") """ - import asyncio - from unittest.mock import AsyncMock, MagicMock - - from fastmcp.server.context import Context - from fastmcp.server.tasks.elicitation import handle_task_input - - # Shared Redis storage that both elicit and handle_task_input will use - redis_storage: dict[str, bytes] = {} - - # Create a mock Redis that uses our shared storage - class MockRedis: - async def set( - self, key: str, value: str | bytes, ex: int | None = None - ) -> None: - redis_storage[key] = value.encode() if isinstance(value, str) else value - - async def get(self, key: str) -> bytes | None: - return redis_storage.get(key) - - async def delete(self, *keys: str) -> None: - for key in keys: - redis_storage.pop(key, None) + mcp = FastMCP("elicit-accept-test") + elicit_started = asyncio.Event() + captured: dict[str, str | None] = {"task_id": None, "session_id": None} - mock_redis = MockRedis() + @mcp.tool(task=True) + async def ask_name(ctx: Context) -> str: + captured["task_id"] = ctx.task_id + captured["session_id"] = ctx.session_id + elicit_started.set() - # Create mock context manager for redis() - class MockRedisContext: - async def __aenter__(self): - return mock_redis + result = await ctx.elicit("What is your name?", str) + if isinstance(result, AcceptedElicitation): + return f"Hello, {result.data}!" + return "No name provided" - async def __aexit__(self, *args): - pass + async with Client(mcp) as client: + task = await client.call_tool("ask_name", {}, task=True) + await asyncio.wait_for(elicit_started.wait(), timeout=5.0) - mock_docket = MagicMock() - mock_docket.redis = lambda: MockRedisContext() - mock_docket.key = lambda k: k + assert captured["task_id"] is not None + assert captured["session_id"] is not None - mock_fastmcp = MagicMock() - mock_fastmcp._docket = mock_docket + # Poll until the "waiting" status is stored in Redis + success = False + for _ in range(40): + success = await handle_task_input( + task_id=captured["task_id"], + session_id=captured["session_id"], + action="accept", + content={"value": "Bob"}, + fastmcp=mcp, + ) + if success: + break + await asyncio.sleep(0.05) - mock_session = MagicMock() - mock_session._fastmcp_state_prefix = "test-session-123" - mock_session.send_notification = AsyncMock() + assert success is True, "handle_task_input should succeed within 2s" - # Create task-aware context (as would be created in background worker) - ctx = Context( - mock_fastmcp, - session=mock_session, - task_id="test-task-456", - ) + await task.wait(timeout=10.0) + result = await task.result() + assert result.data == "Hello, Bob!" - # Verify context is properly configured for background task - assert ctx.is_background_task is True - assert ctx.task_id == "test-task-456" - - # Start elicit in a background task (simulating the Docket worker) - async def run_elicit(): - return await ctx.elicit("What is your name?", str) - - elicit_task = asyncio.create_task(run_elicit()) - - # Wait for elicit to store request and start polling - # The elicit_for_task function stores the request and sends notification - await asyncio.sleep(0.2) - - # ═══════════════════════════════════════════════════════════════════════ - # CLIENT PERSPECTIVE: What does the client see? - # ═══════════════════════════════════════════════════════════════════════ - - # 1. CLIENT RECEIVES: notifications/tasks/updated notification - mock_session.send_notification.assert_called() - notification = mock_session.send_notification.call_args[0][0] - assert notification.method == "notifications/tasks/updated" - - # 2. CLIENT INSPECTS: The notification metadata tells the client: - # - Which task needs input (taskId) - # - What status the task is in (input_required) - # - What message to display (statusMessage) - # - The schema for the expected response (elicitation.requestedSchema) - meta = notification._meta - related_task = meta["modelcontextprotocol.io/related-task"] - - assert related_task["taskId"] == "test-task-456" - assert related_task["status"] == "input_required" - assert related_task["statusMessage"] == "What is your name?" - assert "elicitation" in related_task - assert related_task["elicitation"]["message"] == "What is your name?" - assert "requestedSchema" in related_task["elicitation"] - - # 3. CLIENT RESPONDS: Send input via handle_task_input - # This is what a real client would do when it receives input_required - success = await handle_task_input( - task_id="test-task-456", - session_id="test-session-123", - action="accept", - content={"value": "Alice"}, - fastmcp=mock_fastmcp, - ) - assert success is True, "Client should successfully send input" + async def test_elicit_decline_flow(self): + """E2E: tool elicits input, client declines, tool gets DeclinedElicitation.""" + mcp = FastMCP("elicit-decline-test") + elicit_started = asyncio.Event() + captured: dict[str, str | None] = {"task_id": None, "session_id": None} - # ═══════════════════════════════════════════════════════════════════════ - # TOOL PERSPECTIVE: What does the tool receive? - # ═══════════════════════════════════════════════════════════════════════ + @mcp.tool(task=True) + async def optional_input(ctx: Context) -> str: + captured["task_id"] = ctx.task_id + captured["session_id"] = ctx.session_id + elicit_started.set() - # Wait for elicit to receive the response and return - result = await asyncio.wait_for(elicit_task, timeout=5.0) + result = await ctx.elicit("Want to provide a name?", str) + if isinstance(result, DeclinedElicitation): + return "User declined" + if isinstance(result, AcceptedElicitation): + return f"Got: {result.data}" + return "Cancelled" - # Verify the result contains what the client sent - # AcceptedElicitation has 'action' and 'data' attributes - assert result.action == "accept" - assert result.data == "Alice" # The value from content["value"] + async with Client(mcp) as client: + task = await client.call_tool("optional_input", {}, task=True) + await asyncio.wait_for(elicit_started.wait(), timeout=5.0) - async def test_context_elicit_with_real_docket_memory_backend(self): - """E2E test using Docket's real memory:// backend. + assert captured["task_id"] is not None + assert captured["session_id"] is not None - This test uses the real Docket memory backend instead of mocking Redis, - as suggested by Chris Guidry during code review. The memory:// backend - provides a fully functional in-memory Redis-like store that Docket uses - automatically when running tests. + success = False + for _ in range(40): + success = await handle_task_input( + task_id=captured["task_id"], + session_id=captured["session_id"], + action="decline", + content=None, + fastmcp=mcp, + ) + if success: + break + await asyncio.sleep(0.05) - Flow: - 1. Create FastMCP server with task-enabled tool that calls ctx.elicit() - 2. Start the task via Client (which initializes Docket with memory://) - 3. Background task blocks waiting for client input - 4. Simulate client sending input via handle_task_input() - 5. Task resumes and completes with the elicited value + assert success is True - This demonstrates the complete elicitation flow with real infrastructure. - """ - import asyncio + await task.wait(timeout=10.0) + result = await task.result() + assert result.data == "User declined" - from fastmcp import FastMCP - from fastmcp.client import Client - from fastmcp.server.context import Context - from fastmcp.server.tasks.elicitation import handle_task_input + async def test_elicit_with_pydantic_model(self): + """E2E: tool elicits structured Pydantic input, data round-trips correctly.""" + from pydantic import BaseModel - mcp = FastMCP("elicit-memory-test") + class UserInfo(BaseModel): + name: str + age: int - # Track task state using mutable container (avoids nonlocal) + mcp = FastMCP("elicit-pydantic-test") elicit_started = asyncio.Event() captured: dict[str, str | None] = {"task_id": None, "session_id": None} @mcp.tool(task=True) - async def ask_for_name(ctx: Context) -> str: - """Tool that elicits user's name via background task.""" - # Capture IDs for handle_task_input call + async def get_user_info(ctx: Context) -> str: captured["task_id"] = ctx.task_id captured["session_id"] = ctx.session_id elicit_started.set() - # This will block until client sends input - result = await ctx.elicit("What is your name?", str) - + result = await ctx.elicit("Provide user info", UserInfo) if isinstance(result, AcceptedElicitation): - return f"Hello, {result.data}!" - else: - return "Elicitation was declined or cancelled" + assert isinstance(result.data, UserInfo) + return f"{result.data.name} is {result.data.age}" + return "No info" async with Client(mcp) as client: - # Start the background task - task = await client.call_tool("ask_for_name", {}, task=True) - assert task is not None - assert task.task_id is not None - - # Wait for task to reach elicit() call + task = await client.call_tool("get_user_info", {}, task=True) await asyncio.wait_for(elicit_started.wait(), timeout=5.0) - # Poll until handle_task_input succeeds - # We need to wait for elicit_for_task to store the "waiting" status in Redis - # before we can send input. Using fixed-interval polling (not exponential - # backoff) because we're waiting for state, not recovering from errors. assert captured["task_id"] is not None assert captured["session_id"] is not None - max_attempts = 40 - poll_interval_seconds = 0.05 # 50ms - fast for tests, 2s max total success = False - for _ in range(max_attempts): + for _ in range(40): success = await handle_task_input( task_id=captured["task_id"], session_id=captured["session_id"], action="accept", - content={"value": "Bob"}, + content={"name": "Alice", "age": 30}, fastmcp=mcp, ) if success: break - await asyncio.sleep(poll_interval_seconds) + await asyncio.sleep(0.05) - assert success is True, ( - f"handle_task_input should succeed within {max_attempts * poll_interval_seconds}s" - ) + assert success is True - # Wait for task to complete await task.wait(timeout=10.0) result = await task.result() + assert result.data == "Alice is 30" - # Verify the tool received the elicited value and returned correctly - assert result.data == "Hello, Bob!" + async def test_handle_task_input_rejects_when_not_waiting(self): + """handle_task_input returns False when no task is waiting for input.""" + mcp = FastMCP("reject-test") + + @mcp.tool(task=True) + async def simple_tool() -> str: + return "done" + + async with Client(mcp) as client: + task = await client.call_tool("simple_tool", {}, task=True) + await task.wait(timeout=5.0) + + # Task already completed — no elicitation waiting + success = await handle_task_input( + task_id=task.task_id, + session_id="nonexistent-session", + action="accept", + content={"value": "too late"}, + fastmcp=mcp, + ) + assert success is False diff --git a/tests/server/tasks/test_notifications.py b/tests/server/tasks/test_notifications.py new file mode 100644 index 0000000000..07b2b7f3b6 --- /dev/null +++ b/tests/server/tasks/test_notifications.py @@ -0,0 +1,165 @@ +"""Tests for distributed notification queue (SEP-1686). + +Integration tests verify that the notification queue works end-to-end +using Client(mcp) with the real memory:// Docket backend. +No mocking of Redis, sessions, or Docket internals. +""" + +import asyncio + +import mcp.types + +from fastmcp import FastMCP +from fastmcp.client import Client +from fastmcp.client.messages import MessageHandler +from fastmcp.server.context import Context +from fastmcp.server.elicitation import AcceptedElicitation +from fastmcp.server.tasks.elicitation import handle_task_input +from fastmcp.server.tasks.notifications import ( + get_subscriber_count, +) + + +class NotificationCaptureHandler(MessageHandler): + """Capture server notifications for test assertions.""" + + def __init__(self) -> None: + super().__init__() + self.notifications: list[mcp.types.ServerNotification] = [] + + async def on_notification(self, message: mcp.types.ServerNotification) -> None: + self.notifications.append(message) + + def for_method(self, method: str) -> list[mcp.types.ServerNotification]: + return [ + notification + for notification in self.notifications + if notification.root.method == method + ] + + +class TestNotificationIntegration: + """Integration tests for the notification queue using real Docket memory backend. + + The elicitation flow implicitly validates the full notification pipeline: + 1. Tool calls ctx.elicit() → stores request in Redis → pushes notification + 2. Subscriber picks up notification → sends MCP notification to client + 3. Client calls handle_task_input() → LPUSH response → BLPOP wakes tool + """ + + async def test_notification_delivered_during_elicitation(self): + """Full E2E: notification queue delivers input_required metadata to client.""" + mcp = FastMCP("notification-test") + notification_handler = NotificationCaptureHandler() + elicit_started = asyncio.Event() + captured: dict[str, str | None] = {"task_id": None, "session_id": None} + + @mcp.tool(task=True) + async def elicit_tool(ctx: Context) -> str: + captured["task_id"] = ctx.task_id + captured["session_id"] = ctx.session_id + elicit_started.set() + + result = await ctx.elicit("Enter value", str) + if isinstance(result, AcceptedElicitation): + return f"got: {result.data}" + return "no value" + + async with Client(mcp, message_handler=notification_handler) as client: + task = await client.call_tool("elicit_tool", {}, task=True) + await asyncio.wait_for(elicit_started.wait(), timeout=5.0) + + assert captured["task_id"] is not None + assert captured["session_id"] is not None + + notification: mcp.types.ServerNotification | None = None + for _ in range(40): + candidates = notification_handler.for_method( + "notifications/tasks/status" + ) + for candidate in reversed(candidates): + candidate_meta = getattr(candidate.root, "_meta", None) + related_task = ( + candidate_meta.get("modelcontextprotocol.io/related-task") + if isinstance(candidate_meta, dict) + else None + ) + if ( + isinstance(related_task, dict) + and related_task.get("status") == "input_required" + ): + notification = candidate + break + if notification is not None: + break + await asyncio.sleep(0.05) + + assert notification is not None, "expected notifications/tasks/status" + task_meta = getattr(notification.root, "_meta", None) + assert isinstance(task_meta, dict) + + related_task = task_meta.get("modelcontextprotocol.io/related-task") + assert isinstance(related_task, dict) + assert related_task.get("taskId") == captured["task_id"] + assert related_task.get("status") == "input_required" + + elicitation = related_task.get("elicitation") + assert isinstance(elicitation, dict) + assert elicitation.get("message") == "Enter value" + assert isinstance(elicitation.get("requestId"), str) + assert isinstance(elicitation.get("requestedSchema"), dict) + + success = False + for _ in range(40): + success = await handle_task_input( + task_id=captured["task_id"], + session_id=captured["session_id"], + action="accept", + content={"value": "hello"}, + fastmcp=mcp, + ) + if success: + break + await asyncio.sleep(0.05) + + assert success is True + + await task.wait(timeout=10.0) + result = await task.result() + assert result.data == "got: hello" + + async def test_subscriber_started_and_cleaned_up(self): + """Subscriber starts during background task and stops when client disconnects.""" + mcp = FastMCP("subscriber-test") + tool_started = asyncio.Event() + tool_continue = asyncio.Event() + + @mcp.tool(task=True) + async def lifecycle_tool(ctx: Context) -> str: + tool_started.set() + await asyncio.wait_for(tool_continue.wait(), timeout=10.0) + return "done" + + count_before = get_subscriber_count() + + async with Client(mcp) as client: + task = await client.call_tool("lifecycle_tool", {}, task=True) + await asyncio.wait_for(tool_started.wait(), timeout=5.0) + + # While a background task is running, subscriber should be active + count_during = get_subscriber_count() + assert count_during > count_before + + # Let the tool complete + tool_continue.set() + await task.wait(timeout=5.0) + result = await task.result() + assert result.data == "done" + + # After client disconnects, subscriber should be cleaned up + # Allow brief time for async cleanup + for _ in range(20): + if get_subscriber_count() == count_before: + break + await asyncio.sleep(0.05) + assert get_subscriber_count() == count_before diff --git a/tests/server/tasks/test_task_protocol.py b/tests/server/tasks/test_task_protocol.py index 1d1daa02c1..9c6a88d22e 100644 --- a/tests/server/tasks/test_task_protocol.py +++ b/tests/server/tasks/test_task_protocol.py @@ -48,7 +48,7 @@ async def test_task_metadata_includes_task_id_and_ttl(task_enabled_server): async def test_task_notification_sent_after_submission(task_enabled_server): - """Server sends notifications/tasks/created after task submission.""" + """Server sends an initial task status notification after submission.""" @task_enabled_server.tool(task=True) async def background_tool(message: str) -> str: