|
1 | | -import os |
2 | 1 | import sys |
3 | 2 | from contextlib import asynccontextmanager |
4 | | -from pathlib import Path |
5 | | -from typing import Literal, TextIO |
| 3 | +from typing import TextIO |
6 | 4 |
|
7 | | -import anyio |
8 | | -import anyio.lowlevel |
9 | | -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
10 | | -from anyio.streams.text import TextReceiveStream |
11 | | -from pydantic import BaseModel, Field |
12 | | - |
13 | | -import mcp.types as types |
14 | | -from mcp.shared.message import SessionMessage |
15 | | - |
16 | | -from .win32 import ( |
17 | | - create_windows_process, |
18 | | - get_windows_executable_command, |
19 | | - terminate_windows_process, |
20 | | -) |
21 | | - |
22 | | -# Environment variables to inherit by default |
23 | | -DEFAULT_INHERITED_ENV_VARS = ( |
24 | | - [ |
25 | | - "APPDATA", |
26 | | - "HOMEDRIVE", |
27 | | - "HOMEPATH", |
28 | | - "LOCALAPPDATA", |
29 | | - "PATH", |
30 | | - "PROCESSOR_ARCHITECTURE", |
31 | | - "SYSTEMDRIVE", |
32 | | - "SYSTEMROOT", |
33 | | - "TEMP", |
34 | | - "USERNAME", |
35 | | - "USERPROFILE", |
36 | | - ] |
37 | | - if sys.platform == "win32" |
38 | | - else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] |
39 | | -) |
40 | | - |
41 | | - |
42 | | -def get_default_environment() -> dict[str, str]: |
43 | | - """ |
44 | | - Returns a default environment object including only environment variables deemed |
45 | | - safe to inherit. |
46 | | - """ |
47 | | - env: dict[str, str] = {} |
48 | | - |
49 | | - for key in DEFAULT_INHERITED_ENV_VARS: |
50 | | - value = os.environ.get(key) |
51 | | - if value is None: |
52 | | - continue |
53 | | - |
54 | | - if value.startswith("()"): |
55 | | - # Skip functions, which are a security risk |
56 | | - continue |
57 | | - |
58 | | - env[key] = value |
59 | | - |
60 | | - return env |
61 | | - |
62 | | - |
63 | | -class StdioServerParameters(BaseModel): |
64 | | - command: str |
65 | | - """The executable to run to start the server.""" |
66 | | - |
67 | | - args: list[str] = Field(default_factory=list) |
68 | | - """Command line arguments to pass to the executable.""" |
69 | | - |
70 | | - env: dict[str, str] | None = None |
71 | | - """ |
72 | | - The environment to use when spawning the process. |
73 | | -
|
74 | | - If not specified, the result of get_default_environment() will be used. |
75 | | - """ |
76 | | - |
77 | | - cwd: str | Path | None = None |
78 | | - """The working directory to use when spawning the process.""" |
79 | | - |
80 | | - encoding: str = "utf-8" |
81 | | - """ |
82 | | - The text encoding used when sending/receiving messages to the server |
83 | | -
|
84 | | - defaults to utf-8 |
85 | | - """ |
86 | | - |
87 | | - encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" |
88 | | - """ |
89 | | - The text encoding error handler. |
90 | | -
|
91 | | - See https://docs.python.org/3/library/codecs.html#codec-base-classes for |
92 | | - explanations of possible values |
93 | | - """ |
| 5 | +# Import from the new files |
| 6 | +from .parameters import StdioServerParameters |
| 7 | +from .transport import StdioClientTransport |
94 | 8 |
|
95 | 9 |
|
96 | 10 | @asynccontextmanager |
97 | 11 | async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): |
98 | 12 | """ |
99 | | - Client transport for stdio: this will connect to a server by spawning a |
100 | | - process and communicating with it over stdin/stdout. |
| 13 | + Client transport for stdio: connects to a server by spawning a process |
| 14 | + and communicating with it over stdin/stdout, managed by StdioClientTransport. |
101 | 15 | """ |
102 | | - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] |
103 | | - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] |
104 | | - |
105 | | - write_stream: MemoryObjectSendStream[SessionMessage] |
106 | | - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] |
107 | | - |
108 | | - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) |
109 | | - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) |
110 | | - |
111 | | - command = _get_executable_command(server.command) |
112 | | - |
113 | | - # Open process with stderr piped for capture |
114 | | - process = await _create_platform_compatible_process( |
115 | | - command=command, |
116 | | - args=server.args, |
117 | | - env=( |
118 | | - {**get_default_environment(), **server.env} |
119 | | - if server.env is not None |
120 | | - else get_default_environment() |
121 | | - ), |
122 | | - errlog=errlog, |
123 | | - cwd=server.cwd, |
124 | | - ) |
125 | | - |
126 | | - async def stdout_reader(): |
127 | | - assert process.stdout, "Opened process is missing stdout" |
128 | | - |
129 | | - try: |
130 | | - async with read_stream_writer: |
131 | | - buffer = "" |
132 | | - async for chunk in TextReceiveStream( |
133 | | - process.stdout, |
134 | | - encoding=server.encoding, |
135 | | - errors=server.encoding_error_handler, |
136 | | - ): |
137 | | - lines = (buffer + chunk).split("\n") |
138 | | - buffer = lines.pop() |
139 | | - |
140 | | - for line in lines: |
141 | | - try: |
142 | | - message = types.JSONRPCMessage.model_validate_json(line) |
143 | | - except Exception as exc: |
144 | | - await read_stream_writer.send(exc) |
145 | | - continue |
146 | | - |
147 | | - session_message = SessionMessage(message) |
148 | | - await read_stream_writer.send(session_message) |
149 | | - except anyio.ClosedResourceError: |
150 | | - await anyio.lowlevel.checkpoint() |
151 | | - |
152 | | - async def stdin_writer(): |
153 | | - assert process.stdin, "Opened process is missing stdin" |
| 16 | + transport = StdioClientTransport(server_params=server, errlog=errlog) |
| 17 | + async with transport as streams: |
| 18 | + yield streams |
154 | 19 |
|
155 | | - try: |
156 | | - async with write_stream_reader: |
157 | | - async for session_message in write_stream_reader: |
158 | | - json = session_message.message.model_dump_json( |
159 | | - by_alias=True, exclude_none=True |
160 | | - ) |
161 | | - await process.stdin.send( |
162 | | - (json + "\n").encode( |
163 | | - encoding=server.encoding, |
164 | | - errors=server.encoding_error_handler, |
165 | | - ) |
166 | | - ) |
167 | | - except anyio.ClosedResourceError: |
168 | | - await anyio.lowlevel.checkpoint() |
169 | | - |
170 | | - async with ( |
171 | | - anyio.create_task_group() as tg, |
172 | | - process, |
173 | | - ): |
174 | | - tg.start_soon(stdout_reader) |
175 | | - tg.start_soon(stdin_writer) |
176 | | - try: |
177 | | - yield read_stream, write_stream |
178 | | - finally: |
179 | | - # Clean up process to prevent any dangling orphaned processes |
180 | | - if sys.platform == "win32": |
181 | | - await terminate_windows_process(process) |
182 | | - else: |
183 | | - process.terminate() |
184 | | - await read_stream.aclose() |
185 | | - await write_stream.aclose() |
186 | | - |
187 | | - |
188 | | -def _get_executable_command(command: str) -> str: |
189 | | - """ |
190 | | - Get the correct executable command normalized for the current platform. |
191 | | -
|
192 | | - Args: |
193 | | - command: Base command (e.g., 'uvx', 'npx') |
194 | | -
|
195 | | - Returns: |
196 | | - str: Platform-appropriate command |
197 | | - """ |
198 | | - if sys.platform == "win32": |
199 | | - return get_windows_executable_command(command) |
200 | | - else: |
201 | | - return command |
202 | | - |
203 | | - |
204 | | -async def _create_platform_compatible_process( |
205 | | - command: str, |
206 | | - args: list[str], |
207 | | - env: dict[str, str] | None = None, |
208 | | - errlog: TextIO = sys.stderr, |
209 | | - cwd: Path | str | None = None, |
210 | | -): |
211 | | - """ |
212 | | - Creates a subprocess in a platform-compatible way. |
213 | | - Returns a process handle. |
214 | | - """ |
215 | | - if sys.platform == "win32": |
216 | | - process = await create_windows_process(command, args, env, errlog, cwd) |
217 | | - else: |
218 | | - process = await anyio.open_process( |
219 | | - [command, *args], env=env, stderr=errlog, cwd=cwd |
220 | | - ) |
221 | 20 |
|
222 | | - return process |
| 21 | +# Ensure __all__ or exports are updated if this was a public API change, though |
| 22 | +# stdio_client itself remains the primary public entry point from this file. |
0 commit comments