Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Implementation,
ResourceContentBlock,
SseMcpServer,
StdioMcpServer,
McpServerStdio,
TextContentBlock,
)

Expand Down Expand Up @@ -66,7 +66,7 @@ async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateRespo
return AuthenticateResponse()

async def new_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
) -> NewSessionResponse:
logging.info("Received new session request")
session_id = str(self._next_session_id)
Expand All @@ -75,7 +75,7 @@ async def new_session(
return NewSessionResponse(session_id=session_id, modes=None)

async def load_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
) -> LoadSessionResponse | None:
logging.info("Received load session request %s", session_id)
self._sessions.add(session_id)
Expand Down
4 changes: 2 additions & 2 deletions examples/echo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Implementation,
ResourceContentBlock,
SseMcpServer,
StdioMcpServer,
McpServerStdio,
TextContentBlock,
)

Expand All @@ -43,7 +43,7 @@ async def initialize(
return InitializeResponse(protocol_version=protocol_version)

async def new_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
) -> NewSessionResponse:
return NewSessionResponse(session_id=uuid4().hex)

Expand Down
2 changes: 1 addition & 1 deletion schema/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
refs/tags/v0.6.3
refs/tags/v0.7.0
1 change: 1 addition & 0 deletions schema/meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"authenticate": "authenticate",
"initialize": "initialize",
"session_cancel": "session/cancel",
"session_list": "session/list",
"session_load": "session/load",
"session_new": "session/new",
"session_prompt": "session/prompt",
Expand Down
1,839 changes: 1,170 additions & 669 deletions schema/schema.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions scripts/gen_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def resolve_ref(version: str | None) -> str:

def download_schema(repo: str, ref: str) -> None:
SCHEMA_DIR.mkdir(parents=True, exist_ok=True)
schema_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/schema.json"
meta_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/meta.json"
schema_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/schema.unstable.json"
meta_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/meta.unstable.json"
try:
schema_data = fetch_json(schema_url)
meta_data = fetch_json(meta_url)
Expand Down
17 changes: 10 additions & 7 deletions scripts/gen_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

STDIO_TYPE_LITERAL = 'Literal["2#-datamodel-code-generator-#-object-#-special-#"]'
STDIO_TYPE_PATTERN = re.compile(
r"^ type:\s*Literal\[['\"]2#-datamodel-code-generator-#-object-#-special-#['\"]\]"
r"^ type:\s*Literal\[['\"]McpServerStdio['\"]\]"
r"(?:\s*=\s*['\"][^'\"]+['\"])?\s*$",
re.MULTILINE,
)
Expand All @@ -40,7 +40,6 @@
"AgentOutgoingMessage2": "AgentResponseMessage",
"AgentOutgoingMessage3": "AgentErrorMessage",
"AgentOutgoingMessage4": "AgentNotificationMessage",
"AvailableCommandInput1": "CommandInputHint",
"ClientOutgoingMessage1": "ClientRequestMessage",
"ClientOutgoingMessage2": "ClientResponseMessage",
"ClientOutgoingMessage3": "ClientErrorMessage",
Expand All @@ -52,7 +51,6 @@
"ContentBlock5": "EmbeddedResourceContentBlock",
"McpServer1": "HttpMcpServer",
"McpServer2": "SseMcpServer",
"McpServer3": "StdioMcpServer",
"RequestPermissionOutcome1": "DeniedOutcome",
"RequestPermissionOutcome2": "AllowedOutcome",
"SessionUpdate1": "UserMessageChunk",
Expand All @@ -68,6 +66,10 @@
"ToolCallContent3": "TerminalToolCallContent",
}

ALIASES_MAP = {
"StdioMcpServer": "McpServerStdio",
}

ENUM_LITERAL_MAP: dict[str, tuple[str, ...]] = {
"PermissionOptionKind": (
"allow_once",
Expand All @@ -87,16 +89,15 @@
("PlanEntry", "priority", "PlanEntryPriority", False),
("PlanEntry", "status", "PlanEntryStatus", False),
("PromptResponse", "stop_reason", "StopReason", False),
("ToolCallProgress", "kind", "ToolKind", True),
("ToolCallProgress", "status", "ToolCallStatus", True),
("ToolCallStart", "kind", "ToolKind", True),
("ToolCallStart", "status", "ToolCallStatus", True),
("ToolCall", "kind", "ToolKind", True),
("ToolCall", "status", "ToolCallStatus", True),
("ToolCallUpdate", "kind", "ToolKind", True),
("ToolCallUpdate", "status", "ToolCallStatus", True),
)

DEFAULT_VALUE_OVERRIDES: tuple[tuple[str, str, str], ...] = (
("AgentCapabilities", "mcp_capabilities", "McpCapabilities()"),
("AgentCapabilities", "session_capabilities", "SessionCapabilities()"),
(
"AgentCapabilities",
"prompt_capabilities",
Expand Down Expand Up @@ -222,6 +223,7 @@ def _build_header_block() -> str:

def _build_alias_block() -> str:
alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())]
alias_lines += [f"{old} = {new}" for old, new in sorted(ALIASES_MAP.items())]
return BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n"


Expand Down Expand Up @@ -421,6 +423,7 @@ def _normalize_stdio_model(content: str) -> str:
replacement_line = ' type: Literal["stdio"] = "stdio"'
new_content, count = STDIO_TYPE_PATTERN.subn(replacement_line, content)
if count == 0:
print("Warning: stdio type placeholder not found; no replacements made.", file=sys.stderr)
return content
if count > 1:
print(
Expand Down
8 changes: 5 additions & 3 deletions src/acp/agent/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
SessionNotification,
TerminalOutputRequest,
TerminalOutputResponse,
ToolCall,
ToolCallProgress,
ToolCallStart,
ToolCallUpdate,
UserMessageChunk,
WaitForTerminalExitRequest,
WaitForTerminalExitResponse,
Expand All @@ -56,12 +56,14 @@ def __init__(
input_stream: Any,
output_stream: Any,
listening: bool = True,
*,
use_unstable_protocol: bool = False,
**connection_kwargs: Any,
) -> None:
agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
raise TypeError(_AGENT_CONNECTION_ERROR)
handler = build_agent_router(cast(Agent, agent))
handler = build_agent_router(cast(Agent, agent), use_unstable_protocol=use_unstable_protocol)
self._conn = Connection(handler, input_stream, output_stream, listening=listening, **connection_kwargs)
if on_connect := getattr(agent, "on_connect", None):
on_connect(self)
Expand Down Expand Up @@ -92,7 +94,7 @@ async def session_update(

@param_model(RequestPermissionRequest)
async def request_permission(
self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any
self, options: list[PermissionOption], session_id: str, tool_call: ToolCallUpdate, **kwargs: Any
) -> RequestPermissionResponse:
return await request_model(
self._conn,
Expand Down
7 changes: 5 additions & 2 deletions src/acp/agent/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AuthenticateRequest,
CancelNotification,
InitializeRequest,
ListSessionsRequest,
LoadSessionRequest,
NewSessionRequest,
PromptRequest,
Expand All @@ -21,8 +22,8 @@
__all__ = ["build_agent_router"]


def build_agent_router(agent: Agent) -> MessageRouter:
router = MessageRouter()
def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter:
router = MessageRouter(use_unstable_protocol=use_unstable_protocol)

router.route_request(AGENT_METHODS["initialize"], InitializeRequest, agent, "initialize")
router.route_request(AGENT_METHODS["session_new"], NewSessionRequest, agent, "new_session")
Expand All @@ -33,6 +34,7 @@ def build_agent_router(agent: Agent) -> MessageRouter:
"load_session",
adapt_result=normalize_result,
)
router.route_request(AGENT_METHODS["session_list"], ListSessionsRequest, agent, "list_sessions", unstable=True)
router.route_request(
AGENT_METHODS["session_set_mode"],
SetSessionModeRequest,
Expand All @@ -47,6 +49,7 @@ def build_agent_router(agent: Agent) -> MessageRouter:
agent,
"set_session_model",
adapt_result=normalize_result,
unstable=True,
)
router.route_request(
AGENT_METHODS["authenticate"],
Expand Down
23 changes: 19 additions & 4 deletions src/acp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
Implementation,
InitializeRequest,
InitializeResponse,
ListSessionsRequest,
ListSessionsResponse,
LoadSessionRequest,
LoadSessionResponse,
McpServerStdio,
NewSessionRequest,
NewSessionResponse,
PromptRequest,
Expand All @@ -31,7 +34,6 @@
SetSessionModeRequest,
SetSessionModeResponse,
SseMcpServer,
StdioMcpServer,
TextContentBlock,
)
from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict
Expand All @@ -51,12 +53,14 @@ def __init__(
to_client: Callable[[Agent], Client] | Client,
input_stream: Any,
output_stream: Any,
*,
use_unstable_protocol: bool = False,
**connection_kwargs: Any,
) -> None:
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
raise TypeError(_CLIENT_CONNECTION_ERROR)
client = to_client(cast(Agent, self)) if callable(to_client) else to_client
handler = build_client_router(cast(Client, client))
handler = build_client_router(cast(Client, client), use_unstable_protocol=use_unstable_protocol)
self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs)
if on_connect := getattr(client, "on_connect", None):
on_connect(self)
Expand All @@ -83,7 +87,7 @@ async def initialize(

@param_model(NewSessionRequest)
async def new_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
) -> NewSessionResponse:
return await request_model(
self._conn,
Expand All @@ -94,7 +98,7 @@ async def new_session(

@param_model(LoadSessionRequest)
async def load_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
) -> LoadSessionResponse:
return await request_model_from_dict(
self._conn,
Expand All @@ -103,6 +107,17 @@ async def load_session(
LoadSessionResponse,
)

@param_model(ListSessionsRequest)
async def list_sessions(
self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any
) -> ListSessionsResponse:
return await request_model_from_dict(
self._conn,
AGENT_METHODS["session_list"],
ListSessionsRequest(cursor=cursor, cwd=cwd, field_meta=kwargs or None),
ListSessionsResponse,
)

@param_model(SetSessionModeRequest)
async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse:
return await request_model_from_dict(
Expand Down
4 changes: 2 additions & 2 deletions src/acp/client/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
__all__ = ["build_client_router"]


def build_client_router(client: Client) -> MessageRouter:
router = MessageRouter()
def build_client_router(client: Client, use_unstable_protocol: bool = False) -> MessageRouter:
router = MessageRouter(use_unstable_protocol=use_unstable_protocol)

router.route_request(CLIENT_METHODS["fs_write_text_file"], WriteTextFileRequest, client, "write_text_file")
router.route_request(CLIENT_METHODS["fs_read_text_file"], ReadTextFileRequest, client, "read_text_file")
Expand Down
4 changes: 2 additions & 2 deletions src/acp/contrib/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from ..helpers import text_block, tool_content
from ..schema import PermissionOption, RequestPermissionRequest, RequestPermissionResponse, ToolCall
from ..schema import PermissionOption, RequestPermissionRequest, RequestPermissionResponse, ToolCallUpdate
from .tool_calls import ToolCallTracker, _copy_model_list


Expand Down Expand Up @@ -60,7 +60,7 @@ async def request_for(
description: str | None = None,
options: Sequence[PermissionOption] | None = None,
content: Sequence[Any] | None = None,
tool_call: ToolCall | None = None,
tool_call: ToolCallUpdate | None = None,
) -> RequestPermissionResponse:
"""Request user approval for a tool call."""
if tool_call is None:
Expand Down
15 changes: 11 additions & 4 deletions src/acp/contrib/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from pydantic import BaseModel, ConfigDict

from ..helpers import text_block, tool_content
from ..schema import ToolCall, ToolCallLocation, ToolCallProgress, ToolCallStart, ToolCallStatus, ToolKind
from ..schema import (
ToolCallLocation,
ToolCallProgress,
ToolCallStart,
ToolCallStatus,
ToolCallUpdate,
ToolKind,
)


class _MissingToolCallTitleError(ValueError):
Expand Down Expand Up @@ -91,8 +98,8 @@ def to_view(self) -> TrackedToolCallView:
raw_output=self.raw_output,
)

def to_tool_call_model(self) -> ToolCall:
return ToolCall(
def to_tool_call_model(self) -> ToolCallUpdate:
return ToolCallUpdate(
tool_call_id=self.tool_call_id,
title=self.title,
kind=self.kind,
Expand Down Expand Up @@ -249,7 +256,7 @@ def view(self, external_id: str) -> TrackedToolCallView:
state = self._require_call(external_id)
return state.to_view()

def tool_call_model(self, external_id: str) -> ToolCall:
def tool_call_model(self, external_id: str) -> ToolCallUpdate:
"""Return a deep copy of the tool call suitable for permission requests."""
state = self._require_call(external_id)
return state.to_tool_call_model()
Expand Down
Loading