diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index d4402f1258..d062fd938f 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -2,7 +2,6 @@ import base64 import json -import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Sequence from contextlib import AsyncExitStack, asynccontextmanager @@ -12,6 +11,7 @@ from types import TracebackType from typing import Any +import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.shared.message import SessionMessage from mcp.types import ( @@ -77,6 +77,9 @@ def _get_log_level(self) -> LoggingLevel | None: """Get the log level for the MCP server.""" raise NotImplementedError('MCP Server subclasses must implement this method.') + def _get_client_initialize_timeout(self) -> float: + return 5 # pragma: no cover + def get_prefixed_tool_name(self, tool_name: str) -> str: """Get the tool name with prefix if `tool_prefix` is set.""" return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name @@ -136,7 +139,9 @@ async def __aenter__(self) -> Self: client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream) self._client = await self._exit_stack.enter_async_context(client) - await self._client.initialize() + with anyio.fail_after(self._get_client_initialize_timeout()): + await self._client.initialize() + if log_level := self._get_log_level(): await self._client.set_logging_level(log_level) self.is_running = True @@ -251,6 +256,9 @@ async def main(): e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar` """ + timeout: float = 5 + """ The timeout in seconds to wait for the client to initialize.""" + @asynccontextmanager async def client_streams( self, @@ -267,6 +275,9 @@ def _get_log_level(self) -> LoggingLevel | None: def __repr__(self) -> str: return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})' + def _get_client_initialize_timeout(self) -> float: + return self.timeout + @dataclass class MCPServerHTTP(MCPServer): @@ -312,15 +323,15 @@ async def main(): Useful for authentication, custom headers, or other HTTP-specific configurations. """ - timeout: timedelta | float = timedelta(seconds=5) - """Initial connection timeout as a timedelta for establishing the connection. + timeout: float = 5 + """Initial connection timeout in seconds for establishing the connection. This timeout applies to the initial connection setup and handshake. If the connection cannot be established within this time, the operation will fail. """ - sse_read_timeout: timedelta | float = timedelta(minutes=5) - """Maximum time as a timedelta to wait for new SSE messages before timing out. + sse_read_timeout: float = 300 + """Maximum time as in seconds to wait for new SSE messages before timing out. This timeout applies to the long-lived SSE connection after it's established. If no new messages are received within this time, the connection will be considered stale @@ -343,21 +354,14 @@ async def main(): """ def __post_init__(self): - if not isinstance(self.timeout, timedelta): - warnings.warn( - 'Passing timeout as a float has been deprecated, please use a timedelta instead.', - DeprecationWarning, - stacklevel=2, - ) - self.timeout = timedelta(seconds=self.timeout) + # streamablehttp_client expects timedeltas, so we accept them too to match, + # but primarily work with floats for a simpler user API. - if not isinstance(self.sse_read_timeout, timedelta): - warnings.warn( - 'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.', - DeprecationWarning, - stacklevel=2, - ) - self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout) + if isinstance(self.timeout, timedelta): + self.timeout = self.timeout.total_seconds() + + if isinstance(self.sse_read_timeout, timedelta): + self.sse_read_timeout = self.sse_read_timeout.total_seconds() @asynccontextmanager async def client_streams( @@ -365,24 +369,11 @@ async def client_streams( ) -> AsyncIterator[ tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] ]: # pragma: no cover - if not isinstance(self.timeout, timedelta): - warnings.warn( - 'Passing timeout as a float has been deprecated, please use a timedelta instead.', - DeprecationWarning, - stacklevel=2, - ) - self.timeout = timedelta(seconds=self.timeout) - - if not isinstance(self.sse_read_timeout, timedelta): - warnings.warn( - 'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.', - DeprecationWarning, - stacklevel=2, - ) - self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout) - async with streamablehttp_client( - url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout + url=self.url, + headers=self.headers, + timeout=timedelta(seconds=self.timeout), + sse_read_timeout=timedelta(self.sse_read_timeout), ) as (read_stream, write_stream, _): yield read_stream, write_stream @@ -391,3 +382,6 @@ def _get_log_level(self) -> LoggingLevel | None: def __repr__(self) -> str: # pragma: no cover return f'MCPServerHTTP(url={self.url!r}, tool_prefix={self.tool_prefix!r})' + + def _get_client_initialize_timeout(self) -> float: # pragma: no cover + return self.timeout diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ab3d9b9dbf..ee95885ff2 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -81,30 +81,29 @@ def test_http_server_with_header_and_timeout(): http_server = MCPServerHTTP( url='http://localhost:8000/sse', headers={'my-custom-header': 'my-header-value'}, - timeout=timedelta(seconds=10), - sse_read_timeout=timedelta(seconds=100), + timeout=10, + sse_read_timeout=100, log_level='info', ) assert http_server.url == 'http://localhost:8000/sse' assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value' - assert http_server.timeout == timedelta(seconds=10) - assert http_server.sse_read_timeout == timedelta(seconds=100) + assert http_server.timeout == 10 + assert http_server.sse_read_timeout == 100 assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage] -def test_http_server_with_deprecated_arguments(): - with pytest.warns(DeprecationWarning): - http_server = MCPServerHTTP( - url='http://localhost:8000/sse', - headers={'my-custom-header': 'my-header-value'}, - timeout=10, - sse_read_timeout=100, - log_level='info', - ) +def test_http_server_with_timedelta_arguments(): + http_server = MCPServerHTTP( + url='http://localhost:8000/sse', + headers={'my-custom-header': 'my-header-value'}, + timeout=timedelta(seconds=10), # type: ignore[arg-type] + sse_read_timeout=timedelta(seconds=100), # type: ignore[arg-type] + log_level='info', + ) assert http_server.url == 'http://localhost:8000/sse' assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value' - assert http_server.timeout == timedelta(seconds=10) - assert http_server.sse_read_timeout == timedelta(seconds=100) + assert http_server.timeout == 10 + assert http_server.sse_read_timeout == 100 assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]