diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 3ebe56d2e5..02eb30f5bd 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -606,7 +606,7 @@ def _setup_task_protocol_handlers(self) -> None: ServerResult, ) - from fastmcp.server.tasks.protocol import ( + from fastmcp.server.tasks.requests import ( tasks_cancel_handler, tasks_get_handler, tasks_list_handler, diff --git a/src/fastmcp/server/tasks/config.py b/src/fastmcp/server/tasks/config.py index 1eeecc8bb7..6531eba325 100644 --- a/src/fastmcp/server/tasks/config.py +++ b/src/fastmcp/server/tasks/config.py @@ -15,6 +15,11 @@ # Task execution modes per SEP-1686 / MCP ToolExecution.taskSupport TaskMode = Literal["forbidden", "optional", "required"] +# Default values for task metadata (single source of truth) +DEFAULT_POLL_INTERVAL = timedelta(seconds=5) # Default poll interval +DEFAULT_POLL_INTERVAL_MS = int(DEFAULT_POLL_INTERVAL.total_seconds() * 1000) +DEFAULT_TTL_MS = 60_000 # Default TTL in milliseconds + @dataclass class TaskConfig: @@ -47,7 +52,7 @@ async def flexible_task(): ... """ mode: TaskMode = "optional" - poll_interval: timedelta = timedelta(seconds=5) + poll_interval: timedelta = DEFAULT_POLL_INTERVAL @classmethod def from_bool(cls, value: bool) -> TaskConfig: diff --git a/src/fastmcp/server/tasks/protocol.py b/src/fastmcp/server/tasks/requests.py similarity index 76% rename from src/fastmcp/server/tasks/protocol.py rename to src/fastmcp/server/tasks/requests.py index e4ca09a865..8c7165ebec 100644 --- a/src/fastmcp/server/tasks/protocol.py +++ b/src/fastmcp/server/tasks/requests.py @@ -1,6 +1,7 @@ -"""SEP-1686 task protocol handlers. +"""SEP-1686 task request handlers. -Implements MCP task protocol methods: tasks/get, tasks/result, tasks/list, tasks/cancel, tasks/delete. +Handles MCP task protocol requests: tasks/get, tasks/result, tasks/list, tasks/cancel. +These handlers query and manage existing tasks (contrast with handlers.py which creates tasks). """ from __future__ import annotations @@ -20,6 +21,7 @@ ListTasksResult, ) +from fastmcp.server.tasks.config import DEFAULT_POLL_INTERVAL_MS, DEFAULT_TTL_MS from fastmcp.server.tasks.keys import parse_task_key if TYPE_CHECKING: @@ -37,6 +39,69 @@ } +async def _lookup_task_execution( + docket: Any, + session_id: str, + client_task_id: str, +) -> tuple[Any, str | None, int]: + """Look up task execution and metadata from Redis. + + Consolidates the common pattern of fetching task metadata from Redis, + validating it exists, and retrieving the Docket execution. + + Args: + docket: Docket instance + session_id: Session ID + client_task_id: Client-provided task ID + + Returns: + Tuple of (execution, created_at, poll_interval_ms) + + Raises: + McpError: If task not found or execution not found + """ + # Build Redis keys + redis_key = f"fastmcp:task:{session_id}:{client_task_id}" + created_at_key = f"{redis_key}:created_at" + poll_interval_key = f"{redis_key}:poll_interval" + + # Fetch metadata (single round-trip with mget) + async with docket.redis() as redis: + task_key_bytes, created_at_bytes, poll_interval_bytes = await redis.mget( + redis_key, created_at_key, poll_interval_key + ) + + # Decode and validate task_key + task_key = task_key_bytes.decode("utf-8") if task_key_bytes else None + if not task_key: + raise McpError( + ErrorData(code=INVALID_PARAMS, message=f"Task {client_task_id} not found") + ) + + # Get execution + execution = await docket.get_execution(task_key) + if not execution: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task {client_task_id} execution not found", + ) + ) + + # Parse metadata with defaults + created_at = created_at_bytes.decode("utf-8") if created_at_bytes else None + try: + poll_interval_ms = ( + int(poll_interval_bytes.decode("utf-8")) + if poll_interval_bytes + else DEFAULT_POLL_INTERVAL_MS + ) + except (ValueError, UnicodeDecodeError): + poll_interval_ms = DEFAULT_POLL_INTERVAL_MS + + return execution, created_at, poll_interval_ms + + async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskResult: """Handle MCP 'tasks/get' request (SEP-1686). @@ -61,7 +126,7 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR # Get session ID from Context session_id = ctx.session_id - # Get execution from Docket (use instance attribute for cross-task access) + # Get Docket instance docket = server._docket if docket is None: raise McpError( @@ -71,45 +136,10 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR ) ) - # Look up task metadata from Redis - redis_key = f"fastmcp:task:{session_id}:{client_task_id}" - created_at_key = f"fastmcp:task:{session_id}:{client_task_id}:created_at" - poll_interval_key = f"fastmcp:task:{session_id}:{client_task_id}:poll_interval" - async with docket.redis() as redis: - task_key_bytes = await redis.get(redis_key) - created_at_bytes = await redis.get(created_at_key) - poll_interval_bytes = await redis.get(poll_interval_key) - - task_key = None if task_key_bytes is None else task_key_bytes.decode("utf-8") - created_at = ( - None if created_at_bytes is None else created_at_bytes.decode("utf-8") + # Look up task execution and metadata + execution, created_at, poll_interval_ms = await _lookup_task_execution( + docket, session_id, client_task_id ) - try: - poll_interval_ms = ( - int(poll_interval_bytes.decode("utf-8")) - if poll_interval_bytes - else 5000 # Default to 5 seconds - ) - except (ValueError, UnicodeDecodeError): - poll_interval_ms = 5000 - - if task_key is None: - # Task not found - raise error per MCP protocol - raise McpError( - ErrorData( - code=INVALID_PARAMS, message=f"Task {client_task_id} not found" - ) - ) - - execution = await docket.get_execution(task_key) - if execution is None: - # Task key exists but no execution - raise error - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=f"Task {client_task_id} execution not found", - ) - ) # Sync state from Redis await execution.sync() @@ -138,7 +168,7 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR status=mcp_state, # type: ignore[arg-type] createdAt=created_at, # type: ignore[arg-type] lastUpdatedAt=datetime.now(timezone.utc), - ttl=60000, + ttl=DEFAULT_TTL_MS, pollInterval=poll_interval_ms, statusMessage=status_message, ) @@ -345,6 +375,7 @@ async def tasks_cancel_handler( # Get session ID from Context session_id = ctx.session_id + # Get Docket instance docket = server._docket if docket is None: raise McpError( @@ -354,48 +385,14 @@ async def tasks_cancel_handler( ) ) - # Look up task metadata from Redis - redis_key = f"fastmcp:task:{session_id}:{client_task_id}" - created_at_key = f"fastmcp:task:{session_id}:{client_task_id}:created_at" - poll_interval_key = f"fastmcp:task:{session_id}:{client_task_id}:poll_interval" - async with docket.redis() as redis: - task_key_bytes = await redis.get(redis_key) - created_at_bytes = await redis.get(created_at_key) - poll_interval_bytes = await redis.get(poll_interval_key) - - task_key = None if task_key_bytes is None else task_key_bytes.decode("utf-8") - created_at = ( - None if created_at_bytes is None else created_at_bytes.decode("utf-8") + # Look up task execution and metadata + execution, created_at, poll_interval_ms = await _lookup_task_execution( + docket, session_id, client_task_id ) - try: - poll_interval_ms = ( - int(poll_interval_bytes.decode("utf-8")) - if poll_interval_bytes - else 5000 # Default to 5 seconds - ) - except (ValueError, UnicodeDecodeError): - poll_interval_ms = 5000 - - if task_key is None: - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=f"Invalid taskId: {client_task_id} not found", - ) - ) - - # Check if task exists - execution = await docket.get_execution(task_key) - if execution is None: - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=f"Invalid taskId: {client_task_id} not found", - ) - ) # Cancel via Docket (now sets CANCELLED state natively) - await docket.cancel(task_key) + # Note: We need to get task_key from execution.key for cancellation + await docket.cancel(execution.key) # Return task status with cancelled state # createdAt is REQUIRED per SEP-1686 final spec (line 430) @@ -403,9 +400,11 @@ async def tasks_cancel_handler( return CancelTaskResult( taskId=client_task_id, status="cancelled", - createdAt=created_at or datetime.now(timezone.utc).isoformat(), + createdAt=datetime.fromisoformat(created_at) + if created_at + else datetime.now(timezone.utc), lastUpdatedAt=datetime.now(timezone.utc), - ttl=60_000, + ttl=DEFAULT_TTL_MS, pollInterval=poll_interval_ms, statusMessage="Task cancelled", ) diff --git a/src/fastmcp/server/tasks/subscriptions.py b/src/fastmcp/server/tasks/subscriptions.py index 90059b7455..2f07b2db87 100644 --- a/src/fastmcp/server/tasks/subscriptions.py +++ b/src/fastmcp/server/tasks/subscriptions.py @@ -13,7 +13,7 @@ from docket.execution import ExecutionState from mcp.types import TaskStatusNotification, TaskStatusNotificationParams -from fastmcp.server.tasks.protocol import DOCKET_TO_MCP_STATE +from fastmcp.server.tasks.requests import DOCKET_TO_MCP_STATE from fastmcp.utilities.logging import get_logger if TYPE_CHECKING: