Skip to content
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies = [
"python-dotenv>=1.1.0",
"exceptiongroup>=1.2.2",
"httpx>=0.28.1",
"mcp>=1.19.0,<2.0.0,!=1.21.1",
"mcp @ git+https://github.com/modelcontextprotocol/python-sdk.git@maxisbey/SEP-1686_Tasks",
"openapi-pydantic>=0.5.1",
"platformdirs>=4.0.0",
"pydocket>=0.14.0",
Expand Down
209 changes: 50 additions & 159 deletions src/fastmcp/client/_temporary_sep_1686_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
DO NOT WRITE TESTS FOR THIS FILE - these are temporary hacks.
"""

from __future__ import annotations

import datetime
import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import mcp.types
Expand All @@ -27,7 +25,6 @@
_default_list_roots_callback,
_default_sampling_callback,
)
from pydantic import BaseModel

from fastmcp.client.messages import Message, MessageHandler

Expand All @@ -40,64 +37,6 @@
# ═══════════════════════════════════════════════════════════════════════════


class TaskCapableClientSession(ClientSession):
"""Custom ClientSession that declares task capability.

Overrides initialize() to set experimental={"tasks": {}} in ClientCapabilities.
"""

async def initialize(self) -> mcp.types.InitializeResult:
"""Initialize with task capability declaration."""
# Build capabilities
sampling = (
mcp.types.SamplingCapability()
if self._sampling_callback != _default_sampling_callback
else None
)
elicitation = (
mcp.types.ElicitationCapability()
if self._elicitation_callback != _default_elicitation_callback
else None
)
roots = (
mcp.types.RootsCapability(listChanged=True)
if self._list_roots_callback != _default_list_roots_callback
else None
)

# Send initialize request with task capability
result = await self.send_request(
mcp.types.ClientRequest(
mcp.types.InitializeRequest(
params=mcp.types.InitializeRequestParams(
protocolVersion=mcp.types.LATEST_PROTOCOL_VERSION,
capabilities=mcp.types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental={"tasks": {}},
roots=roots,
),
clientInfo=self._client_info,
),
)
),
mcp.types.InitializeResult,
)

# Validate protocol version
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(
f"Unsupported protocol version from the server: {result.protocolVersion}"
)

# Send initialized notification
await self.send_notification(
mcp.types.ClientNotification(mcp.types.InitializedNotification())
)

return result


async def task_capable_initialize(
session: ClientSession,
) -> mcp.types.InitializeResult:
Expand Down Expand Up @@ -160,128 +99,80 @@ async def task_capable_initialize(


# ═══════════════════════════════════════════════════════════════════════════
# 2. Task Protocol Types (SDK doesn't have these yet)
# 2. Client-Side Type Helpers
# ═══════════════════════════════════════════════════════════════════════════


class TasksGetRequest(BaseModel):
"""Request for tasks/get MCP method."""

method: Literal["tasks/get"] = "tasks/get"
params: TasksGetParams


class TasksGetParams(BaseModel):
"""Parameters for tasks/get request."""

taskId: str
_meta: dict[str, Any] | None = None


class TasksGetResult(BaseModel):
"""Result from tasks/get MCP method."""
class TaskStatusResponse(pydantic.BaseModel):
"""Response from tasks/get endpoint."""

taskId: str
task_id: str = pydantic.Field(alias="taskId")
status: Literal["working", "input_required", "completed", "failed", "cancelled"]
createdAt: str
ttl: int | None = None
pollInterval: int | None = None


class TasksResultRequest(BaseModel):
"""Request for tasks/result MCP method."""

method: Literal["tasks/result"] = "tasks/result"
params: TasksResultParams


class TasksResultParams(BaseModel):
"""Parameters for tasks/result request."""

taskId: str
_meta: dict[str, Any] | None = None


class TasksListRequest(BaseModel):
"""Request for tasks/list MCP method."""

method: Literal["tasks/list"] = "tasks/list"
params: TasksListParams


class TasksListParams(BaseModel):
"""Parameters for tasks/list request."""

cursor: str | None = None
limit: int = 50
_meta: dict[str, Any] | None = None


class TasksListResult(BaseModel):
"""Result from tasks/list MCP method."""

tasks: list[dict[str, Any]]
nextCursor: str | None = None


class TasksDeleteRequest(BaseModel):
"""Request for tasks/delete MCP method."""

method: Literal["tasks/delete"] = "tasks/delete"
params: TasksDeleteParams


class TasksDeleteParams(BaseModel):
"""Parameters for tasks/delete request."""

taskId: str
_meta: dict[str, Any] | None = None
created_at: datetime.datetime = pydantic.Field(alias="createdAt")
ttl: int | None = pydantic.Field(default=None, alias="ttl")
poll_interval: int | None = pydantic.Field(default=None, alias="pollInterval")
status_message: str | None = pydantic.Field(default=None, alias="statusMessage")

model_config = pydantic.ConfigDict(populate_by_name=True)

class TasksDeleteResult(BaseModel):
"""Result from tasks/delete MCP method."""

_meta: dict[str, Any] | None = None
class TasksResponse(pydantic.BaseModel):
"""Generic response wrapper for task protocol methods.

SEP-1686 task responses are dicts that can represent CallToolResult,
GetPromptResult, or ReadResourceResult. This wrapper just passes
through the raw dict.
"""

# ═══════════════════════════════════════════════════════════════════════════
# 3. Client-Side Type Helpers
# ═══════════════════════════════════════════════════════════════════════════
model_config = {"extra": "allow"}

@classmethod
def model_validate(cls, obj: Any) -> Any:
"""Parse response dict back into appropriate MCP type.

@dataclass
class CallToolResult:
"""Parsed result from a tool call."""
The server sends MCP result objects (CallToolResult, GetPromptResult,
ReadResourceResult) serialized as dicts. We parse them back for the client.
"""
if not isinstance(obj, dict):
return obj

content: list[mcp.types.ContentBlock]
structured_content: dict[str, Any] | None
meta: dict[str, Any] | None
data: Any = None
is_error: bool = False
# Try to detect and parse the result type based on structure
import mcp.types

# Check for tool result (has 'content' field)
if "content" in obj:
try:
return mcp.types.CallToolResult.model_validate(obj)
except Exception:
pass

class TaskStatusResponse(pydantic.BaseModel):
"""Response from tasks/get endpoint."""
# Check for prompt result (has 'messages' field)
if "messages" in obj:
try:
return mcp.types.GetPromptResult.model_validate(obj)
except Exception:
pass

task_id: str = pydantic.Field(alias="taskId")
status: Literal["working", "input_required", "completed", "failed", "cancelled"]
created_at: str = pydantic.Field(alias="createdAt")
ttl: int | None = pydantic.Field(default=None, alias="ttl")
poll_interval: int | None = pydantic.Field(default=None, alias="pollInterval")
status_message: str | None = pydantic.Field(default=None, alias="statusMessage")
# Check for resource result (has 'contents' field)
if "contents" in obj:
try:
return mcp.types.ReadResourceResult.model_validate(obj)
except Exception:
pass

model_config = pydantic.ConfigDict(populate_by_name=True)
# Fall back to returning dict as-is
return obj


# ═══════════════════════════════════════════════════════════════════════════
# 4. Task Notification Routing
# 3. Task Notification Routing
# ═══════════════════════════════════════════════════════════════════════════


class ClientMessageHandler(MessageHandler):
class TaskNotificationHandler(MessageHandler):
"""MessageHandler that routes task status notifications to Task objects."""

def __init__(self, client: Client):
def __init__(self, client: "Client"):
super().__init__()
self._client_ref: weakref.ref[Client] = weakref.ref(client)

Expand Down
Loading
Loading