diff --git a/docs/servers/middleware.mdx b/docs/servers/middleware.mdx index a701078165..ddf283fb90 100644 --- a/docs/servers/middleware.mdx +++ b/docs/servers/middleware.mdx @@ -555,6 +555,50 @@ my_tool = Tool.from_function(fn=my_tool_fn, name="my_tool") mcp.add_middleware(ToolInjectionMiddleware(tools=[my_tool])) ``` +### Response Limiting + + + +```python +from fastmcp.server.middleware.response_limiting import ResponseLimitingMiddleware +``` + +Large tool responses can overwhelm LLM context windows or cause memory issues. You can add response-limiting middleware to enforce size constraints on tool outputs. + +```python +from fastmcp import FastMCP +from fastmcp.server.middleware.response_limiting import ResponseLimitingMiddleware + +mcp = FastMCP("MyServer") + +# Limit all tool responses to 500KB +mcp.add_middleware(ResponseLimitingMiddleware(max_size=500_000)) + +@mcp.tool +def search(query: str) -> str: + # This could return a very large result + return "x" * 1_000_000 # 1MB response + +# When called, the response will be truncated to ~500KB with: +# "...\n\n[Response truncated due to size limit]" +``` + +When a response exceeds the limit, the middleware extracts all text content, joins it together, truncates to fit within the limit, and returns a single `TextContent` block. For non-text responses, the serialized JSON is used as the text source. + +```python +# Limit only specific tools +mcp.add_middleware(ResponseLimitingMiddleware( + max_size=100_000, + tools=["search", "fetch_data"], +)) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `max_size` | `int` | `1_000_000` | Maximum response size in bytes (1MB default) | +| `truncation_suffix` | `str` | `"\n\n[Response truncated due to size limit]"` | Suffix appended to truncated responses | +| `tools` | `list[str] \| None` | `None` | Limit only these tools (None = all tools) | + ### Combining Middleware Order matters. Place middleware that should run first (on the way in) earliest: diff --git a/src/fastmcp/server/middleware/response_limiting.py b/src/fastmcp/server/middleware/response_limiting.py new file mode 100644 index 0000000000..df83e81a08 --- /dev/null +++ b/src/fastmcp/server/middleware/response_limiting.py @@ -0,0 +1,125 @@ +"""Response limiting middleware for controlling tool response sizes.""" + +from __future__ import annotations + +import logging + +import mcp.types as mt +import pydantic_core +from mcp.types import TextContent + +from fastmcp.tools.tool import ToolResult + +from .middleware import CallNext, Middleware, MiddlewareContext + +__all__ = ["ResponseLimitingMiddleware"] + +logger = logging.getLogger(__name__) + + +class ResponseLimitingMiddleware(Middleware): + """Middleware that limits the response size of tool calls. + + Intercepts tool call responses and enforces size limits. If a response + exceeds the limit, it extracts text content, truncates it, and returns + a single TextContent block. + + Example: + ```python + from fastmcp import FastMCP + from fastmcp.server.middleware.response_limiting import ( + ResponseLimitingMiddleware, + ) + + mcp = FastMCP("MyServer") + + # Limit all tool responses to 500KB + mcp.add_middleware(ResponseLimitingMiddleware(max_size=500_000)) + + # Limit only specific tools + mcp.add_middleware( + ResponseLimitingMiddleware( + max_size=100_000, + tools=["search", "fetch_data"], + ) + ) + ``` + """ + + def __init__( + self, + *, + max_size: int = 1_000_000, + truncation_suffix: str = "\n\n[Response truncated due to size limit]", + tools: list[str] | None = None, + ) -> None: + """Initialize response limiting middleware. + + Args: + max_size: Maximum response size in bytes. Defaults to 1MB (1,000,000). + truncation_suffix: Suffix to append when truncating responses. + Defaults to "\\n\\n[Response truncated due to size limit]". + tools: List of tool names to apply limiting to. If None, applies to all. + """ + if max_size <= 0: + raise ValueError(f"max_size must be positive, got {max_size}") + self.max_size = max_size + self.truncation_suffix = truncation_suffix + self.tools = set(tools) if tools is not None else None + + def _truncate_to_result(self, text: str) -> ToolResult: + """Truncate text to fit within max_size and wrap in ToolResult.""" + suffix_bytes = len(self.truncation_suffix.encode("utf-8")) + # Account for JSON wrapper overhead: {"content":[{"type":"text","text":"..."}]} + overhead = 50 + target_size = self.max_size - suffix_bytes - overhead + + if target_size <= 0: + # Edge case: max_size too small for even the suffix + truncated = self.truncation_suffix + else: + # Truncate to target size, preserving UTF-8 boundaries + encoded = text.encode("utf-8") + if len(encoded) <= target_size: + truncated = text + self.truncation_suffix + else: + truncated = ( + encoded[:target_size].decode("utf-8", errors="ignore") + + self.truncation_suffix + ) + + return ToolResult(content=[TextContent(type="text", text=truncated)]) + + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Intercept tool calls and limit response size.""" + result = await call_next(context) + + # Check if we should limit this tool + if self.tools is not None and context.message.name not in self.tools: + return result + + # Measure serialized size + serialized = pydantic_core.to_json(result, fallback=str) + if len(serialized) <= self.max_size: + return result + + # Over limit: extract text, truncate, return single TextContent + logger.warning( + "Tool %r response exceeds size limit: %d bytes > %d bytes, truncating", + context.message.name, + len(serialized), + self.max_size, + ) + + texts = [b.text for b in result.content if isinstance(b, TextContent)] + text = ( + "\n\n".join(texts) + if texts + else serialized.decode("utf-8", errors="replace") + ) + + return self._truncate_to_result(text) diff --git a/tests/server/middleware/test_response_limiting.py b/tests/server/middleware/test_response_limiting.py new file mode 100644 index 0000000000..4e89e05dec --- /dev/null +++ b/tests/server/middleware/test_response_limiting.py @@ -0,0 +1,155 @@ +"""Tests for ResponseLimitingMiddleware.""" + +import pytest +from mcp.types import ImageContent, TextContent + +from fastmcp import Client, FastMCP +from fastmcp.server.middleware.response_limiting import ResponseLimitingMiddleware +from fastmcp.tools.tool import ToolResult + + +class TestResponseLimitingMiddleware: + """Tests for ResponseLimitingMiddleware.""" + + @pytest.fixture + def mcp_server(self) -> FastMCP: + """Create a basic MCP server for testing.""" + return FastMCP("test-server") + + async def test_response_under_limit_passes_unchanged(self, mcp_server: FastMCP): + """Test that responses under the limit pass through unchanged.""" + mcp_server.add_middleware(ResponseLimitingMiddleware(max_size=1_000_000)) + + @mcp_server.tool() + def small_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="hello world")]) + + async with Client(mcp_server) as client: + result = await client.call_tool("small_tool", {}) + assert len(result.content) == 1 + assert result.content[0].text == "hello world" + + async def test_response_over_limit_is_truncated(self, mcp_server: FastMCP): + """Test that responses over the limit are truncated.""" + mcp_server.add_middleware(ResponseLimitingMiddleware(max_size=500)) + + @mcp_server.tool() + def large_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="x" * 10_000)]) + + async with Client(mcp_server) as client: + result = await client.call_tool("large_tool", {}) + assert len(result.content) == 1 + assert "[Response truncated due to size limit]" in result.content[0].text + # Verify truncated result fits within limit + assert len(result.content[0].text.encode("utf-8")) < 500 + + async def test_tool_filtering(self, mcp_server: FastMCP): + """Test that tool filtering only applies to specified tools.""" + mcp_server.add_middleware( + ResponseLimitingMiddleware(max_size=100, tools=["limited_tool"]) + ) + + @mcp_server.tool() + def limited_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="x" * 10_000)]) + + @mcp_server.tool() + def unlimited_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="y" * 10_000)]) + + async with Client(mcp_server) as client: + # Limited tool should be truncated + result = await client.call_tool("limited_tool", {}) + assert "[Response truncated" in result.content[0].text + + # Unlimited tool should pass through + result = await client.call_tool("unlimited_tool", {}) + assert "y" * 100 in result.content[0].text + + async def test_empty_tools_list_limits_nothing(self, mcp_server: FastMCP): + """Test that empty tools list means no tools are limited.""" + mcp_server.add_middleware(ResponseLimitingMiddleware(max_size=100, tools=[])) + + @mcp_server.tool() + def any_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="x" * 10_000)]) + + async with Client(mcp_server) as client: + result = await client.call_tool("any_tool", {}) + # Should NOT be truncated + assert "[Response truncated" not in result.content[0].text + + async def test_custom_truncation_suffix(self, mcp_server: FastMCP): + """Test that custom truncation suffix is applied.""" + mcp_server.add_middleware( + ResponseLimitingMiddleware(max_size=200, truncation_suffix="\n[CUT]") + ) + + @mcp_server.tool() + def large_tool() -> ToolResult: + return ToolResult(content=[TextContent(type="text", text="x" * 10_000)]) + + async with Client(mcp_server) as client: + result = await client.call_tool("large_tool", {}) + assert "[CUT]" in result.content[0].text + + async def test_multiple_text_blocks_combined(self, mcp_server: FastMCP): + """Test that multiple text blocks are combined when truncating.""" + mcp_server.add_middleware(ResponseLimitingMiddleware(max_size=300)) + + @mcp_server.tool() + def multi_block() -> ToolResult: + return ToolResult( + content=[ + TextContent(type="text", text="First: " + "a" * 500), + TextContent(type="text", text="Second: " + "b" * 500), + ] + ) + + async with Client(mcp_server) as client: + result = await client.call_tool("multi_block", {}) + # Both blocks should be joined and truncated + assert len(result.content) == 1 + assert "[Response truncated" in result.content[0].text + + async def test_binary_only_content_serialized(self, mcp_server: FastMCP): + """Test that binary-only responses fall back to serialized content.""" + mcp_server.add_middleware(ResponseLimitingMiddleware(max_size=200)) + + @mcp_server.tool() + def binary_tool() -> ToolResult: + return ToolResult( + content=[ + ImageContent(type="image", data="x" * 10_000, mimeType="image/png") + ] + ) + + async with Client(mcp_server) as client: + result = await client.call_tool("binary_tool", {}) + # Should be truncated (using serialized fallback) + assert len(result.content) == 1 + assert "[Response truncated" in result.content[0].text + + async def test_default_max_size_is_1mb(self): + """Test that the default max size is 1MB.""" + middleware = ResponseLimitingMiddleware() + assert middleware.max_size == 1_000_000 + + def test_invalid_max_size_raises(self): + """Test that zero or negative max_size raises ValueError.""" + with pytest.raises(ValueError, match="max_size must be positive"): + ResponseLimitingMiddleware(max_size=0) + with pytest.raises(ValueError, match="max_size must be positive"): + ResponseLimitingMiddleware(max_size=-100) + + def test_utf8_truncation_preserves_characters(self): + """Test that UTF-8 truncation doesn't break multi-byte characters.""" + middleware = ResponseLimitingMiddleware(max_size=100) + # Text with multi-byte characters (emoji) + text = "Hello 🌍 World 🎉 Test " * 100 + result = middleware._truncate_to_result(text) + # Should not raise and should be valid UTF-8 + content = result.content[0] + assert isinstance(content, TextContent) + content.text.encode("utf-8")