Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def _setup_task_protocol_handlers(self) -> None:
ServerResult,
)

from fastmcp.server.tasks.protocol import (
from fastmcp.server.tasks.requests import (
tasks_cancel_handler,
tasks_get_handler,
tasks_list_handler,
Expand Down
7 changes: 6 additions & 1 deletion src/fastmcp/server/tasks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
# Task execution modes per SEP-1686 / MCP ToolExecution.taskSupport
TaskMode = Literal["forbidden", "optional", "required"]

# Default values for task metadata (single source of truth)
DEFAULT_POLL_INTERVAL = timedelta(seconds=5) # Default poll interval
DEFAULT_POLL_INTERVAL_MS = int(DEFAULT_POLL_INTERVAL.total_seconds() * 1000)
DEFAULT_TTL_MS = 60_000 # Default TTL in milliseconds


@dataclass
class TaskConfig:
Expand Down Expand Up @@ -47,7 +52,7 @@ async def flexible_task(): ...
"""

mode: TaskMode = "optional"
poll_interval: timedelta = timedelta(seconds=5)
poll_interval: timedelta = DEFAULT_POLL_INTERVAL

@classmethod
def from_bool(cls, value: bool) -> TaskConfig:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""SEP-1686 task protocol handlers.
"""SEP-1686 task request handlers.

Implements MCP task protocol methods: tasks/get, tasks/result, tasks/list, tasks/cancel, tasks/delete.
Handles MCP task protocol requests: tasks/get, tasks/result, tasks/list, tasks/cancel.
These handlers query and manage existing tasks (contrast with handlers.py which creates tasks).
"""

from __future__ import annotations
Expand All @@ -20,6 +21,7 @@
ListTasksResult,
)

from fastmcp.server.tasks.config import DEFAULT_POLL_INTERVAL_MS, DEFAULT_TTL_MS
from fastmcp.server.tasks.keys import parse_task_key

if TYPE_CHECKING:
Expand All @@ -37,6 +39,69 @@
}


async def _lookup_task_execution(
docket: Any,
session_id: str,
client_task_id: str,
) -> tuple[Any, str | None, int]:
"""Look up task execution and metadata from Redis.

Consolidates the common pattern of fetching task metadata from Redis,
validating it exists, and retrieving the Docket execution.

Args:
docket: Docket instance
session_id: Session ID
client_task_id: Client-provided task ID

Returns:
Tuple of (execution, created_at, poll_interval_ms)

Raises:
McpError: If task not found or execution not found
"""
# Build Redis keys
redis_key = f"fastmcp:task:{session_id}:{client_task_id}"
created_at_key = f"{redis_key}:created_at"
poll_interval_key = f"{redis_key}:poll_interval"

# Fetch metadata (single round-trip with mget)
async with docket.redis() as redis:
task_key_bytes, created_at_bytes, poll_interval_bytes = await redis.mget(
redis_key, created_at_key, poll_interval_key
)

# Decode and validate task_key
task_key = task_key_bytes.decode("utf-8") if task_key_bytes else None
if not task_key:
raise McpError(
ErrorData(code=INVALID_PARAMS, message=f"Task {client_task_id} not found")
)

# Get execution
execution = await docket.get_execution(task_key)
if not execution:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task {client_task_id} execution not found",
)
)

# Parse metadata with defaults
created_at = created_at_bytes.decode("utf-8") if created_at_bytes else None
try:
poll_interval_ms = (
int(poll_interval_bytes.decode("utf-8"))
if poll_interval_bytes
else DEFAULT_POLL_INTERVAL_MS
)
except (ValueError, UnicodeDecodeError):
poll_interval_ms = DEFAULT_POLL_INTERVAL_MS

return execution, created_at, poll_interval_ms


async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskResult:
"""Handle MCP 'tasks/get' request (SEP-1686).

Expand All @@ -61,7 +126,7 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR
# Get session ID from Context
session_id = ctx.session_id

# Get execution from Docket (use instance attribute for cross-task access)
# Get Docket instance
docket = server._docket
if docket is None:
raise McpError(
Expand All @@ -71,45 +136,10 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR
)
)

# Look up task metadata from Redis
redis_key = f"fastmcp:task:{session_id}:{client_task_id}"
created_at_key = f"fastmcp:task:{session_id}:{client_task_id}:created_at"
poll_interval_key = f"fastmcp:task:{session_id}:{client_task_id}:poll_interval"
async with docket.redis() as redis:
task_key_bytes = await redis.get(redis_key)
created_at_bytes = await redis.get(created_at_key)
poll_interval_bytes = await redis.get(poll_interval_key)

task_key = None if task_key_bytes is None else task_key_bytes.decode("utf-8")
created_at = (
None if created_at_bytes is None else created_at_bytes.decode("utf-8")
# Look up task execution and metadata
execution, created_at, poll_interval_ms = await _lookup_task_execution(
docket, session_id, client_task_id
)
try:
poll_interval_ms = (
int(poll_interval_bytes.decode("utf-8"))
if poll_interval_bytes
else 5000 # Default to 5 seconds
)
except (ValueError, UnicodeDecodeError):
poll_interval_ms = 5000

if task_key is None:
# Task not found - raise error per MCP protocol
raise McpError(
ErrorData(
code=INVALID_PARAMS, message=f"Task {client_task_id} not found"
)
)

execution = await docket.get_execution(task_key)
if execution is None:
# Task key exists but no execution - raise error
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task {client_task_id} execution not found",
)
)

# Sync state from Redis
await execution.sync()
Expand Down Expand Up @@ -138,7 +168,7 @@ async def tasks_get_handler(server: FastMCP, params: dict[str, Any]) -> GetTaskR
status=mcp_state, # type: ignore[arg-type]
createdAt=created_at, # type: ignore[arg-type]
lastUpdatedAt=datetime.now(timezone.utc),
ttl=60000,
ttl=DEFAULT_TTL_MS,
pollInterval=poll_interval_ms,
statusMessage=status_message,
)
Expand Down Expand Up @@ -345,6 +375,7 @@ async def tasks_cancel_handler(
# Get session ID from Context
session_id = ctx.session_id

# Get Docket instance
docket = server._docket
if docket is None:
raise McpError(
Expand All @@ -354,58 +385,26 @@ async def tasks_cancel_handler(
)
)

# Look up task metadata from Redis
redis_key = f"fastmcp:task:{session_id}:{client_task_id}"
created_at_key = f"fastmcp:task:{session_id}:{client_task_id}:created_at"
poll_interval_key = f"fastmcp:task:{session_id}:{client_task_id}:poll_interval"
async with docket.redis() as redis:
task_key_bytes = await redis.get(redis_key)
created_at_bytes = await redis.get(created_at_key)
poll_interval_bytes = await redis.get(poll_interval_key)

task_key = None if task_key_bytes is None else task_key_bytes.decode("utf-8")
created_at = (
None if created_at_bytes is None else created_at_bytes.decode("utf-8")
# Look up task execution and metadata
execution, created_at, poll_interval_ms = await _lookup_task_execution(
docket, session_id, client_task_id
)
try:
poll_interval_ms = (
int(poll_interval_bytes.decode("utf-8"))
if poll_interval_bytes
else 5000 # Default to 5 seconds
)
except (ValueError, UnicodeDecodeError):
poll_interval_ms = 5000

if task_key is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Invalid taskId: {client_task_id} not found",
)
)

# Check if task exists
execution = await docket.get_execution(task_key)
if execution is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Invalid taskId: {client_task_id} not found",
)
)

# Cancel via Docket (now sets CANCELLED state natively)
await docket.cancel(task_key)
# Note: We need to get task_key from execution.key for cancellation
await docket.cancel(execution.key)

# Return task status with cancelled state
# createdAt is REQUIRED per SEP-1686 final spec (line 430)
# Per spec lines 447-448: SHOULD NOT include related-task metadata in tasks/cancel
return CancelTaskResult(
taskId=client_task_id,
status="cancelled",
createdAt=created_at or datetime.now(timezone.utc).isoformat(),
createdAt=datetime.fromisoformat(created_at)
if created_at
else datetime.now(timezone.utc),
lastUpdatedAt=datetime.now(timezone.utc),
ttl=60_000,
ttl=DEFAULT_TTL_MS,
pollInterval=poll_interval_ms,
statusMessage="Task cancelled",
)
2 changes: 1 addition & 1 deletion src/fastmcp/server/tasks/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from docket.execution import ExecutionState
from mcp.types import TaskStatusNotification, TaskStatusNotificationParams

from fastmcp.server.tasks.protocol import DOCKET_TO_MCP_STATE
from fastmcp.server.tasks.requests import DOCKET_TO_MCP_STATE
from fastmcp.utilities.logging import get_logger

if TYPE_CHECKING:
Expand Down
Loading