Merged
52 changes: 44 additions & 8 deletions src/fastmcp/server/context.py
@@ -338,9 +338,13 @@ async def report_progress(
) -> None:
"""Report progress for the current operation.

Works in both foreground (MCP progress notifications) and background
(Docket task execution) contexts.

Args:
progress: Current progress value e.g. 24
total: Optional total value e.g. 100
message: Optional status message describing current progress
"""

progress_token = (
@@ -349,16 +353,48 @@ async def report_progress(
else None
)

if progress_token is None:
# Foreground: Send MCP progress notification if we have a token
if progress_token is not None:
await self.session.send_progress_notification(
progress_token=progress_token,
progress=progress,
total=total,
message=message,
related_request_id=self.request_id,
)
return

await self.session.send_progress_notification(
progress_token=progress_token,
progress=progress,
total=total,
message=message,
related_request_id=self.request_id,
)
# Background: Update Docket execution progress (stored in Redis)
# This makes progress visible via tasks/get and notifications/tasks/status
from fastmcp.server.dependencies import is_docket_available

if not is_docket_available():
return

try:
from docket.dependencies import Dependency

# Get current execution from worker context
execution = Dependency.execution.get()

# Update progress in Redis using Docket's progress API.
# Docket only exposes increment() (relative), so we compute
# the delta from the last reported value stored on this execution.
if total is not None:
await execution.progress.set_total(int(total))

current = int(progress)
last: int = getattr(execution, "_fastmcp_last_progress", 0)
delta = current - last
if delta > 0:
await execution.progress.increment(delta)
execution._fastmcp_last_progress = current # type: ignore[attr-defined]

if message is not None:
await execution.progress.set_message(message)
except LookupError:
# Not running in Docket worker context - no progress tracking available
pass

async def _paginate_list(
self,
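With this change, a single `ctx.report_progress()` call serves both execution modes: given a progress token it sends an MCP progress notification, and otherwise it falls through to Docket's Redis-backed progress, translating the absolute `progress` value into the relative `increment(delta)` that Docket exposes. A minimal usage sketch (the server name and tool body are illustrative, assuming the standard FastMCP tool API):

```python
from fastmcp import Context, FastMCP

mcp = FastMCP("progress-demo")  # hypothetical server name

@mcp.tool
async def crunch(items: list[str], ctx: Context) -> int:
    """Process items, reporting absolute progress as we go."""
    total = len(items)
    for i, item in enumerate(items, start=1):
        ...  # one unit of work per item
        # The same call works on either path: an MCP progress notification
        # in the foreground, or a Docket increment(delta) against Redis
        # when this tool runs as a background task.
        await ctx.report_progress(progress=i, total=total, message=f"{i}/{total}")
    return total
```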
8 changes: 8 additions & 0 deletions src/fastmcp/server/tasks/__init__.py
@@ -11,15 +11,23 @@
get_client_task_id_from_key,
parse_task_key,
)
from fastmcp.server.tasks.notifications import (
ensure_subscriber_running,
push_notification,
stop_subscriber,
)

__all__ = [
"TaskConfig",
"TaskMeta",
"TaskMode",
"build_task_key",
"elicit_for_task",
"ensure_subscriber_running",
"get_client_task_id_from_key",
"get_task_capabilities",
"handle_task_input",
"parse_task_key",
"push_notification",
"stop_subscriber",
]
139 changes: 100 additions & 39 deletions src/fastmcp/server/tasks/elicitation.py
@@ -5,7 +5,7 @@
an active request context, so elicitation requires special handling:

1. Set task status to "input_required" via Redis
2. Send notifications/tasks/updated with elicitation metadata
2. Send notifications/tasks/status with elicitation metadata
Collaborator: Thank you! 😅

3. Wait for client to send input via tasks/sendInput
4. Resume task execution with the provided input

@@ -15,11 +15,11 @@

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast

import mcp.types
from mcp import ServerSession
@@ -75,12 +75,21 @@ async def elicit_for_task(
# Generate a unique request ID for this elicitation
request_id = str(uuid.uuid4())

# Get session ID for Redis key construction
session_id = getattr(session, "_fastmcp_state_prefix", None)
if session_id is None:
# Generate a session ID if not already set
session_id = str(uuid.uuid4())
session._fastmcp_state_prefix = session_id # type: ignore[attr-defined]
# Get session ID from task context (authoritative source for background tasks)
# This is extracted from the Docket execution key: {session_id}:{task_id}:...
from fastmcp.server.dependencies import get_task_context

task_context = get_task_context()
if task_context is not None:
session_id = task_context.session_id
else:
# Fallback: try to get from session attribute (shouldn't happen in background)
session_id = getattr(session, "_fastmcp_state_prefix", None)
if session_id is None:
raise RuntimeError(
"Cannot determine session_id for elicitation. "
"This typically means elicit_for_task() was called outside a Docket worker context."
)

# Store elicitation request in Redis
request_key = ELICIT_REQUEST_KEY.format(session_id=session_id, task_id=task_id)
@@ -107,13 +116,24 @@
ex=ELICIT_TTL_SECONDS,
)

# Send task status update notification with input_required status
# This follows SEP-1686 for background task status updates
notification = mcp.types.JSONRPCNotification(
jsonrpc="2.0",
method="notifications/tasks/updated",
params={},
_meta={ # type: ignore[call-arg]
# Send task status update notification with input_required status.
# Use notifications/tasks/status so typed MCP clients can consume it.
#
# NOTE: We use the distributed notification queue instead of session.send_notification().
# This enables notifications to work when workers run in separate processes
# (a pattern inspired by Azure Web PubSub / Service Bus).
timestamp = datetime.now(timezone.utc).isoformat()
notification_dict = {
"method": "notifications/tasks/status",
"params": {
"taskId": task_id,
"status": "input_required",
"statusMessage": message,
"createdAt": timestamp,
"lastUpdatedAt": timestamp,
"ttl": ELICIT_TTL_SECONDS * 1000,
},
"_meta": {
"modelcontextprotocol.io/related-task": {
"taskId": task_id,
"status": "input_required",
@@ -125,49 +145,87 @@
},
}
},
)
}

# Push notification to Redis queue (works from any process)
# Server's subscriber loop will forward to client
from fastmcp.server.tasks.notifications import push_notification

# Send notification (best effort - task status is stored in Redis)
# Log failures for debugging but don't fail the elicitation
try:
await session.send_notification(notification) # type: ignore[arg-type]
await push_notification(session_id, notification_dict, docket)
except Exception as e:
# Fail fast: if the notification can't be queued, the client won't know to
# respond. Return cancel immediately rather than waiting out the 1-hour timeout.
logger.warning(
"Failed to send input_required notification for task %s: %s",
"Failed to queue input_required notification for task %s, cancelling elicitation: %s",
task_id,
e,
)
# Best-effort cleanup
try:
async with docket.redis() as redis:
await redis.delete(
docket.key(request_key),
docket.key(status_key),
)
except Exception:
pass # Keys will expire via TTL
return mcp.types.ElicitResult(action="cancel", content=None)

# Wait for response (poll Redis)
# In a production implementation, this could use Redis pub/sub for lower latency
# Wait for response using BLPOP (blocking pop)
# This is much more efficient than polling - single Redis round-trip
# that blocks until a response is pushed, vs 7,200 round-trips/hour with polling
max_wait_seconds = ELICIT_TTL_SECONDS
poll_interval = 0.5 # seconds

for _ in range(int(max_wait_seconds / poll_interval)):
try:
async with docket.redis() as redis:
response_data = await redis.get(docket.key(response_key))
if response_data:
# BLPOP blocks until an item is pushed to the list or timeout
# Returns tuple of (key, value) or None on timeout
result = await cast(
Any,
redis.blpop(
[docket.key(response_key)],
timeout=max_wait_seconds,
),
)

if result:
# result is (key, value) tuple
_key, response_data = result
response = json.loads(response_data)

# Clean up Redis keys
await redis.delete(
docket.key(request_key),
docket.key(response_key),
docket.key(status_key),
)

# Convert to ElicitResult
return mcp.types.ElicitResult(
action=response.get("action", "accept"),
content=response.get("content"),
)
except Exception as e:
logger.warning(
"BLPOP failed for task %s elicitation, falling back to cancel: %s",
task_id,
e,
)

await asyncio.sleep(poll_interval)

# Timeout - treat as cancellation
async with docket.redis() as redis:
await redis.delete(
docket.key(request_key),
docket.key(response_key),
docket.key(status_key),
# Timeout or error - treat as cancellation
# Best-effort cleanup - if Redis is unavailable, keys will expire via TTL
try:
async with docket.redis() as redis:
await redis.delete(
docket.key(request_key),
docket.key(response_key),
docket.key(status_key),
)
except Exception as cleanup_error:
logger.debug(
"Failed to clean up elicitation keys for task %s (will expire via TTL): %s",
task_id,
cleanup_error,
)

return mcp.types.ElicitResult(action="cancel", content=None)
@@ -213,12 +271,15 @@ async def handle_task_input(
if status is None or status.decode("utf-8") != "waiting":
return False

# Store the response
await redis.set(
# Push response to list - this wakes up the BLPOP in elicit_for_task
# Using LPUSH instead of SET enables the efficient blocking wait pattern
await redis.lpush( # type: ignore[invalid-await] # redis-py union type (sync/async)
docket.key(response_key),
json.dumps(response),
ex=ELICIT_TTL_SECONDS,
)
# Set TTL on the response list (in case BLPOP doesn't consume it)
await redis.expire(docket.key(response_key), ELICIT_TTL_SECONDS)

# Update status to "responded"
await redis.set(
docket.key(status_key),
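The two halves of the diff above form an LPUSH/BLPOP handshake on the same Redis key: `elicit_for_task` blocks on `BLPOP` while `handle_task_input` wakes it with `LPUSH`. Distilled to the bare pattern (a sketch using redis-py's asyncio client directly; `response_key` stands in for the Docket-prefixed key, and error handling is simplified):

```python
import json

import redis.asyncio as aioredis

async def wait_for_response(r: aioredis.Redis, response_key: str, timeout: int) -> dict | None:
    # Worker side: a single blocking round-trip until handle_task_input
    # pushes the response, instead of a GET every 0.5 seconds.
    result = await r.blpop([response_key], timeout=timeout)
    if result is None:
        return None  # timed out; the caller treats this as a cancellation
    _key, payload = result
    return json.loads(payload)

async def deliver_response(r: aioredis.Redis, response_key: str, response: dict, ttl: int) -> None:
    # Server side: LPUSH wakes the blocked BLPOP. The TTL covers the case
    # where no worker is waiting and the payload would otherwise linger.
    await r.lpush(response_key, json.dumps(response))
    await r.expire(response_key, ttl)
```

Over the one-hour elicitation TTL, a single blocked connection replaces the roughly 7,200 round-trips (3,600 s / 0.5 s) of the old polling loop.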
63 changes: 52 additions & 11 deletions src/fastmcp/server/tasks/handlers.py
@@ -17,13 +17,16 @@
from fastmcp.server.dependencies import _current_docket, get_context
from fastmcp.server.tasks.config import TaskMeta
from fastmcp.server.tasks.keys import build_task_key
from fastmcp.utilities.logging import get_logger

if TYPE_CHECKING:
from fastmcp.prompts.prompt import Prompt
from fastmcp.resources.resource import Resource
from fastmcp.resources.template import ResourceTemplate
from fastmcp.tools.tool import Tool

logger = get_logger(__name__)

# Redis mapping TTL buffer: Add 15 minutes to Docket's execution_ttl
TASK_MAPPING_TTL_BUFFER_SECONDS = 15 * 60

@@ -109,21 +112,31 @@ async def submit_to_docket(

register_task_session(session_id, ctx.session)

# Send notifications/tasks/created per SEP-1686 (mandatory)
# Send BEFORE queuing to avoid race where task completes before notification
notification = mcp.types.JSONRPCNotification(
jsonrpc="2.0",
method="notifications/tasks/created",
params={}, # Empty params per spec
_meta={ # type: ignore[call-arg] # _meta is Pydantic alias for meta field
"modelcontextprotocol.io/related-task": {
# Send an initial tasks/status notification before queueing.
# This guarantees clients can observe task creation immediately.
notification = mcp.types.TaskStatusNotification.model_validate(
{
"method": "notifications/tasks/status",
"params": {
"taskId": server_task_id,
}
},
"status": "working",
"statusMessage": "Task submitted",
"createdAt": created_at,
"lastUpdatedAt": created_at,
"ttl": ttl_ms,
"pollInterval": poll_interval_ms,
},
"_meta": {
"modelcontextprotocol.io/related-task": {
"taskId": server_task_id,
}
},
}
)
server_notification = mcp.types.ServerNotification(notification)
with suppress(Exception):
# Don't let notification failures break task creation
await ctx.session.send_notification(notification) # type: ignore[arg-type]
await ctx.session.send_notification(server_notification)

# Queue function to Docket by key (result storage via execution_ttl)
# Use component.add_to_docket() which handles calling conventions
@@ -151,6 +164,34 @@ async def submit_to_docket(
poll_interval_ms,
)

# Start notification subscriber for distributed elicitation (idempotent)
# This enables ctx.elicit() to work when workers run in separate processes
# Subscriber forwards notifications from Redis queue to client session
from fastmcp.server.tasks.notifications import (
ensure_subscriber_running,
stop_subscriber,
)

try:
await ensure_subscriber_running(session_id, ctx.session, docket)

# Register cleanup callback on session exit (once per session)
# This ensures subscriber is stopped when the session disconnects
if (
hasattr(ctx.session, "_exit_stack")
and ctx.session._exit_stack is not None
and not getattr(ctx.session, "_notification_cleanup_registered", False)
):

async def _cleanup_subscriber() -> None:
await stop_subscriber(session_id)

ctx.session._exit_stack.push_async_callback(_cleanup_subscriber)
ctx.session._notification_cleanup_registered = True # type: ignore[attr-defined]
except Exception as e:
# Non-fatal: elicitation will still work via polling fallback
logger.debug("Failed to start notification subscriber: %s", e)

# Return CreateTaskResult with proper Task object
# Tasks MUST begin in "working" status per SEP-1686 final spec (line 381)
return mcp.types.CreateTaskResult(
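The subscriber that `submit_to_docket` starts is what carries worker-side notifications (such as `input_required`) back to the connected client when the worker lives in another process: notifications are queued in Redis, and a per-session loop drains the queue and forwards them. A sketch of that shape (the queue key scheme, timeout, and untyped forwarding are assumptions; the real `fastmcp.server.tasks.notifications` module presumably validates payloads into typed MCP notifications before sending):

```python
import asyncio
import json
from typing import Any

import redis.asyncio as aioredis

_subscribers: dict[str, asyncio.Task] = {}

async def ensure_subscriber_running(session_id: str, session: Any, r: aioredis.Redis) -> None:
    # Idempotent: at most one forwarding loop per client session.
    existing = _subscribers.get(session_id)
    if existing is not None and not existing.done():
        return
    _subscribers[session_id] = asyncio.create_task(_forward_loop(session_id, session, r))

async def _forward_loop(session_id: str, session: Any, r: aioredis.Redis) -> None:
    queue_key = f"notifications:{session_id}"  # hypothetical key scheme
    while True:
        # Short timeout so task cancellation is observed promptly.
        item = await r.blpop([queue_key], timeout=5)
        if item is None:
            continue
        _key, raw = item
        # The real module presumably wraps this payload in a typed MCP
        # notification before forwarding to the client session.
        await session.send_notification(json.loads(raw))

async def stop_subscriber(session_id: str) -> None:
    # Registered on the session's exit stack so the loop dies with the session.
    task = _subscribers.pop(session_id, None)
    if task is not None:
        task.cancel()
```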