diff --git a/examples/mount_example.py b/examples/mount_example.py index 5e42b67a18..228c90f29e 100644 --- a/examples/mount_example.py +++ b/examples/mount_example.py @@ -1,7 +1,6 @@ """Example of mounting FastMCP apps together. -This example demonstrates how to mount FastMCP apps together using -the ToolManager's import_tools functionality. It shows how to: +This example demonstrates how to mount FastMCP apps together. It shows how to: 1. Create sub-applications for different domains 2. Mount those sub-applications to a main application @@ -103,10 +102,6 @@ async def get_server_details(): print(f" - Imported from weather app: {weather_resources}") print(f" - Imported from news app: {news_resources}") - # Let's try to access resources using the prefixed URI - weather_data = await app._read_resource_mcp(uri="weather://weather/forecast") - print(f"\nWeather data from prefixed URI: {weather_data}") - if __name__ == "__main__": # First run our async function to display info diff --git a/src/fastmcp/contrib/component_manager/component_service.py b/src/fastmcp/contrib/component_manager/component_service.py index 6b37573582..e21c486e84 100644 --- a/src/fastmcp/contrib/component_manager/component_service.py +++ b/src/fastmcp/contrib/component_manager/component_service.py @@ -25,7 +25,7 @@ def _get_mounted_server_and_key( Args: provider: The provider to check. key: The transformed component key. - component_type: Either "tool" (for tools/prompts) or "resource". + component_type: Either "tool", "prompt", or "resource". Returns: Tuple of (server, original_key) if the key matches this provider, @@ -35,6 +35,8 @@ def _get_mounted_server_and_key( # TransformingProvider - reverse the transformation if component_type == "resource": original = provider._reverse_resource_uri(key) + elif component_type == "prompt": + original = provider._reverse_prompt_name(key) else: original = provider._reverse_tool_name(key) @@ -55,9 +57,6 @@ class ComponentService: def __init__(self, server: FastMCP): self._server = server - self._tool_manager = server._tool_manager - self._resource_manager = server._resource_manager - self._prompt_manager = server._prompt_manager async def _enable_tool(self, key: str) -> Tool: """Handle 'enableTool' requests. @@ -71,7 +70,7 @@ async def _enable_tool(self, key: str) -> Tool: logger.debug("Enabling tool: %s", key) # 1. Check local tools first. The server will have already applied its filter. - if key in self._server._tool_manager._tools: + if key in self._server._local_provider._tools: tool: Tool = await self._server.get_tool(key) tool.enable() return tool @@ -98,7 +97,7 @@ async def _disable_tool(self, key: str) -> Tool: logger.debug("Disable tool: %s", key) # 1. Check local tools first. The server will have already applied its filter. - if key in self._server._tool_manager._tools: + if key in self._server._local_provider._tools: tool: Tool = await self._server.get_tool(key) tool.disable() return tool @@ -125,11 +124,11 @@ async def _enable_resource(self, key: str) -> Resource | ResourceTemplate: logger.debug("Enabling resource: %s", key) # 1. Check local resources first. The server will have already applied its filter. - if key in self._resource_manager._resources: + if key in self._server._local_provider._resources: resource: Resource = await self._server.get_resource(key) resource.enable() return resource - if key in self._resource_manager._templates: + if key in self._server._local_provider._templates: template: ResourceTemplate = await self._server.get_resource_template(key) template.enable() return template @@ -158,11 +157,11 @@ async def _disable_resource(self, key: str) -> Resource | ResourceTemplate: logger.debug("Disable resource: %s", key) # 1. Check local resources first. The server will have already applied its filter. - if key in self._resource_manager._resources: + if key in self._server._local_provider._resources: resource: Resource = await self._server.get_resource(key) resource.disable() return resource - if key in self._resource_manager._templates: + if key in self._server._local_provider._templates: template: ResourceTemplate = await self._server.get_resource_template(key) template.disable() return template @@ -191,14 +190,14 @@ async def _enable_prompt(self, key: str) -> Prompt: logger.debug("Enabling prompt: %s", key) # 1. Check local prompts first. The server will have already applied its filter. - if key in self._server._prompt_manager._prompts: + if key in self._server._local_provider._prompts: prompt: Prompt = await self._server.get_prompt(key) prompt.enable() return prompt # 2. Check mounted servers via FastMCPProvider/TransformingProvider for provider in self._server._providers: - result = _get_mounted_server_and_key(provider, key, "tool") + result = _get_mounted_server_and_key(provider, key, "prompt") if result is not None: server, unprefixed = result mounted_service = ComponentService(server) @@ -217,14 +216,14 @@ async def _disable_prompt(self, key: str) -> Prompt: """ # 1. Check local prompts first. The server will have already applied its filter. - if key in self._server._prompt_manager._prompts: + if key in self._server._local_provider._prompts: prompt: Prompt = await self._server.get_prompt(key) prompt.disable() return prompt # 2. Check mounted servers via FastMCPProvider/TransformingProvider for provider in self._server._providers: - result = _get_mounted_server_and_key(provider, key, "tool") + result = _get_mounted_server_and_key(provider, key, "prompt") if result is not None: server, unprefixed = result mounted_service = ComponentService(server) diff --git a/src/fastmcp/prompts/__init__.py b/src/fastmcp/prompts/__init__.py index 440816289b..0b6cf50ee2 100644 --- a/src/fastmcp/prompts/__init__.py +++ b/src/fastmcp/prompts/__init__.py @@ -1,10 +1,8 @@ -from .prompt import Message, Prompt, PromptResult, PromptMessage -from .prompt_manager import PromptManager +from .prompt import Message, Prompt, PromptMessage, PromptResult __all__ = [ "Message", "Prompt", - "PromptManager", "PromptMessage", "PromptResult", ] diff --git a/src/fastmcp/prompts/prompt_manager.py b/src/fastmcp/prompts/prompt_manager.py deleted file mode 100644 index b2951da28e..0000000000 --- a/src/fastmcp/prompts/prompt_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations as _annotations - -import warnings -from collections.abc import Awaitable, Callable -from typing import Any - -import mcp.types - -from fastmcp import settings -from fastmcp.exceptions import FastMCPError, NotFoundError, PromptError -from fastmcp.prompts.prompt import ( - FunctionPrompt, - Prompt, - PromptResult, - _PromptFnReturn, -) -from fastmcp.settings import DuplicateBehavior -from fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class PromptManager: - """Manages FastMCP prompts.""" - - def __init__( - self, - duplicate_behavior: DuplicateBehavior | None = None, - mask_error_details: bool | None = None, - ): - self._prompts: dict[str, Prompt] = {} - self.mask_error_details = ( - settings.mask_error_details - if mask_error_details is None - else mask_error_details - ) - - # Default to "warn" if None is provided - if duplicate_behavior is None: - duplicate_behavior = "warn" - - if duplicate_behavior not in DuplicateBehavior.__args__: - raise ValueError( - f"Invalid duplicate_behavior: {duplicate_behavior}. " - f"Must be one of: {', '.join(DuplicateBehavior.__args__)}" - ) - - self.duplicate_behavior = duplicate_behavior - - async def has_prompt(self, key: str) -> bool: - """Check if a prompt exists.""" - prompts = await self.get_prompts() - return key in prompts - - async def get_prompt(self, key: str) -> Prompt: - """Get prompt by key.""" - prompts = await self.get_prompts() - if key in prompts: - return prompts[key] - raise NotFoundError(f"Unknown prompt: {key}") - - async def get_prompts(self) -> dict[str, Prompt]: - """ - Gets the complete, unfiltered inventory of local prompts. - """ - return dict(self._prompts) - - def add_prompt_from_fn( - self, - fn: Callable[..., _PromptFnReturn | Awaitable[_PromptFnReturn]], - name: str | None = None, - description: str | None = None, - tags: set[str] | None = None, - ) -> FunctionPrompt: - """Create a prompt from a function.""" - # deprecated in 2.7.0 - if settings.deprecation_warnings: - warnings.warn( - "PromptManager.add_prompt_from_fn() is deprecated. Use Prompt.from_function() and call add_prompt() instead.", - DeprecationWarning, - stacklevel=2, - ) - prompt = FunctionPrompt.from_function( - fn, name=name, description=description, tags=tags - ) - return self.add_prompt(prompt) # type: ignore - - def add_prompt(self, prompt: Prompt) -> Prompt: - """Add a prompt to the manager.""" - # Check for duplicates - existing = self._prompts.get(prompt.key) - if existing: - if self.duplicate_behavior == "warn": - logger.warning(f"Prompt already exists: {prompt.key}") - self._prompts[prompt.key] = prompt - elif self.duplicate_behavior == "replace": - self._prompts[prompt.key] = prompt - elif self.duplicate_behavior == "error": - raise ValueError(f"Prompt already exists: {prompt.key}") - elif self.duplicate_behavior == "ignore": - return existing - else: - self._prompts[prompt.key] = prompt - return prompt - - async def render_prompt( - self, - name: str, - arguments: dict[str, Any] | None = None, - ) -> PromptResult | mcp.types.CreateTaskResult: - """ - Internal API for servers: Finds and renders a prompt. - - Note: Full error handling (logging, masking) is done at the FastMCP - server level. This method provides basic error wrapping for direct usage. - - Returns: - PromptResult for synchronous execution, or CreateTaskResult if - the prompt was submitted to Docket for background execution. - """ - prompt = await self.get_prompt(name) - try: - return await prompt._render(arguments) - except FastMCPError: - raise - except Exception as e: - if self.mask_error_details: - raise PromptError(f"Error rendering prompt {name!r}") from e - raise PromptError(f"Error rendering prompt {name!r}: {e}") from e diff --git a/src/fastmcp/resources/__init__.py b/src/fastmcp/resources/__init__.py index 5462cd4a84..6632102324 100644 --- a/src/fastmcp/resources/__init__.py +++ b/src/fastmcp/resources/__init__.py @@ -1,5 +1,4 @@ from .resource import FunctionResource, Resource, ResourceContent -from .resource_manager import ResourceManager from .template import ResourceTemplate from .types import ( BinaryResource, @@ -17,7 +16,6 @@ "HttpResource", "Resource", "ResourceContent", - "ResourceManager", "ResourceTemplate", "TextResource", ] diff --git a/src/fastmcp/resources/resource_manager.py b/src/fastmcp/resources/resource_manager.py deleted file mode 100644 index e021e5ac95..0000000000 --- a/src/fastmcp/resources/resource_manager.py +++ /dev/null @@ -1,334 +0,0 @@ -"""Resource manager functionality.""" - -from __future__ import annotations - -import inspect -import warnings -from collections.abc import Callable -from typing import Any - -import mcp.types -from pydantic import AnyUrl - -from fastmcp import settings -from fastmcp.exceptions import FastMCPError, NotFoundError, ResourceError -from fastmcp.resources.resource import Resource, ResourceContent -from fastmcp.resources.template import ( - ResourceTemplate, - match_uri_template, -) -from fastmcp.settings import DuplicateBehavior -from fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class ResourceManager: - """Manages FastMCP resources.""" - - def __init__( - self, - duplicate_behavior: DuplicateBehavior | None = None, - mask_error_details: bool | None = None, - ): - """Initialize the ResourceManager. - - Args: - duplicate_behavior: How to handle duplicate resources - (warn, error, replace, ignore) - mask_error_details: Whether to mask error details from exceptions - other than ResourceError - """ - self._resources: dict[str, Resource] = {} - self._templates: dict[str, ResourceTemplate] = {} - self.mask_error_details = mask_error_details or settings.mask_error_details - - # Default to "warn" if None is provided - if duplicate_behavior is None: - duplicate_behavior = "warn" - - if duplicate_behavior not in DuplicateBehavior.__args__: - raise ValueError( - f"Invalid duplicate_behavior: {duplicate_behavior}. " - f"Must be one of: {', '.join(DuplicateBehavior.__args__)}" - ) - self.duplicate_behavior = duplicate_behavior - - async def get_resources(self) -> dict[str, Resource]: - """Get all registered resources, keyed by URI.""" - return dict(self._resources) - - async def get_resource_templates(self) -> dict[str, ResourceTemplate]: - """Get all registered templates, keyed by URI template.""" - return dict(self._templates) - - def add_resource_or_template_from_fn( - self, - fn: Callable[..., Any], - uri: str, - name: str | None = None, - description: str | None = None, - mime_type: str | None = None, - tags: set[str] | None = None, - ) -> Resource | ResourceTemplate: - """Add a resource or template to the manager from a function. - - Args: - fn: The function to register as a resource or template - uri: The URI for the resource or template - name: Optional name for the resource or template - description: Optional description of the resource or template - mime_type: Optional MIME type for the resource or template - tags: Optional set of tags for categorizing the resource or template - - Returns: - The added resource or template. If a resource or template with the same URI already exists, - returns the existing resource or template. - """ - from fastmcp.server.context import Context - - # Check if this should be a template - has_uri_params = "{" in uri and "}" in uri - # check if the function has any parameters (other than injected context) - has_func_params = any( - p - for p in inspect.signature(fn).parameters.values() - if p.annotation is not Context - ) - - if has_uri_params or has_func_params: - return self.add_template_from_fn( - fn, uri, name, description, mime_type, tags - ) - elif not has_uri_params and not has_func_params: - return self.add_resource_from_fn( - fn, uri, name, description, mime_type, tags - ) - else: - raise ValueError( - "Invalid resource or template definition due to a " - "mismatch between URI parameters and function parameters." - ) - - def add_resource_from_fn( - self, - fn: Callable[..., Any], - uri: str, - name: str | None = None, - description: str | None = None, - mime_type: str | None = None, - tags: set[str] | None = None, - ) -> Resource: - """Add a resource to the manager from a function. - - Args: - fn: The function to register as a resource - uri: The URI for the resource - name: Optional name for the resource - description: Optional description of the resource - mime_type: Optional MIME type for the resource - tags: Optional set of tags for categorizing the resource - - Returns: - The added resource. If a resource with the same URI already exists, - returns the existing resource. - """ - # deprecated in 2.7.0 - if settings.deprecation_warnings: - warnings.warn( - "add_resource_from_fn is deprecated. Use Resource.from_function() and call add_resource() instead.", - DeprecationWarning, - stacklevel=2, - ) - resource = Resource.from_function( - fn=fn, - uri=uri, - name=name, - description=description, - mime_type=mime_type, - tags=tags, - ) - return self.add_resource(resource) - - def add_resource(self, resource: Resource) -> Resource: - """Add a resource to the manager. - - Args: - resource: A Resource instance to add. The resource's .key attribute - (which is str(uri)) will be used as the storage key. To use a - different key, change the uri via model_copy(update={"uri": new_uri}). - """ - existing = self._resources.get(resource.key) - if existing: - if self.duplicate_behavior == "warn": - logger.warning(f"Resource already exists: {resource.key}") - self._resources[resource.key] = resource - elif self.duplicate_behavior == "replace": - self._resources[resource.key] = resource - elif self.duplicate_behavior == "error": - raise ValueError(f"Resource already exists: {resource.key}") - elif self.duplicate_behavior == "ignore": - return existing - self._resources[resource.key] = resource - return resource - - def add_template_from_fn( - self, - fn: Callable[..., Any], - uri_template: str, - name: str | None = None, - description: str | None = None, - mime_type: str | None = None, - tags: set[str] | None = None, - ) -> ResourceTemplate: - """Create a template from a function.""" - # deprecated in 2.7.0 - if settings.deprecation_warnings: - warnings.warn( - "add_template_from_fn is deprecated. Use ResourceTemplate.from_function() and call add_template() instead.", - DeprecationWarning, - stacklevel=2, - ) - template = ResourceTemplate.from_function( - fn, - uri_template=uri_template, - name=name, - description=description, - mime_type=mime_type, - tags=tags, - ) - return self.add_template(template) - - def add_template(self, template: ResourceTemplate) -> ResourceTemplate: - """Add a template to the manager. - - Args: - template: A ResourceTemplate instance to add. The template's .key attribute - (which is uri_template) will be used as the storage key. To use a - different key, change uri_template via model_copy(update={"uri_template": new_uri}). - - Returns: - The added template. If a template with the same URI already exists, - returns the existing template. - """ - existing = self._templates.get(template.key) - if existing: - if self.duplicate_behavior == "warn": - logger.warning(f"Template already exists: {template.key}") - self._templates[template.key] = template - elif self.duplicate_behavior == "replace": - self._templates[template.key] = template - elif self.duplicate_behavior == "error": - raise ValueError(f"Template already exists: {template.key}") - elif self.duplicate_behavior == "ignore": - return existing - self._templates[template.key] = template - return template - - async def has_resource(self, uri: AnyUrl | str) -> bool: - """Check if a resource exists.""" - uri_str = str(uri) - - # First check concrete resources (local and mounted) - resources = await self.get_resources() - if uri_str in resources: - return True - - # Then check templates (local and mounted) only if not found in concrete resources - templates = await self.get_resource_templates() - for template_key in templates: - if match_uri_template(uri_str, template_key) is not None: - return True - - return False - - async def get_resource(self, uri: AnyUrl | str) -> Resource: - """Get resource by URI, checking concrete resources first, then templates. - - Args: - uri: The URI of the resource to get - - Raises: - NotFoundError: If no resource or template matching the URI is found. - """ - uri_str = str(uri) - logger.debug("Getting resource", extra={"uri": uri_str}) - - # First check concrete resources - resources = await self.get_resources() - if resource := resources.get(uri_str): - return resource - - # Then check templates - templates = await self.get_resource_templates() - for storage_key, template in templates.items(): - # Try to match against the storage key (which might be a custom key) - if (params := match_uri_template(uri_str, storage_key)) is not None: - try: - return await template.create_resource( - uri_str, - params=params, - ) - # Pass through FastMCPErrors as-is - except FastMCPError as e: - logger.error(f"Error creating resource from template: {e}") - raise e - # Handle other exceptions - except Exception as e: - logger.error(f"Error creating resource from template: {e}") - if self.mask_error_details: - # Mask internal details - raise ValueError("Error creating resource from template") from e - else: - # Include original error details - raise ValueError( - f"Error creating resource from template: {e}" - ) from e - - raise NotFoundError(f"Unknown resource: {uri_str}") - - async def read_resource( - self, uri: AnyUrl | str - ) -> ResourceContent | mcp.types.CreateTaskResult: - """ - Internal API for servers: Finds and reads a resource. - - Note: Full error handling (logging, masking) is done at the FastMCP - server level. This method provides basic error wrapping for direct usage. - - Returns: - ResourceContent for synchronous execution, or CreateTaskResult if - the resource was submitted to Docket for background execution. - """ - uri_str = str(uri) - - # Check local resources first - if uri_str in self._resources: - resource = await self.get_resource(uri_str) - try: - return await resource._read() - except FastMCPError: - raise - except Exception as e: - if self.mask_error_details: - raise ResourceError(f"Error reading resource {uri_str!r}") from e - raise ResourceError(f"Error reading resource {uri_str!r}: {e}") from e - - # Check local templates if not found in concrete resources - for key, template in self._templates.items(): - if (params := match_uri_template(uri_str, key)) is not None: - try: - resource = await template.create_resource(uri_str, params=params) - return await resource._read() - except FastMCPError: - raise - except Exception as e: - if self.mask_error_details: - raise ResourceError( - f"Error reading resource from template {uri_str!r}" - ) from e - raise ResourceError( - f"Error reading resource from template {uri_str!r}: {e}" - ) from e - - raise NotFoundError(f"Resource {uri_str!r} not found.") diff --git a/src/fastmcp/server/providers/__init__.py b/src/fastmcp/server/providers/__init__.py index 5669f6d73f..25af32a55d 100644 --- a/src/fastmcp/server/providers/__init__.py +++ b/src/fastmcp/server/providers/__init__.py @@ -29,6 +29,7 @@ async def get_tool(self, name: str) -> Tool | None: from fastmcp.server.providers.base import Provider from fastmcp.server.providers.fastmcp_provider import FastMCPProvider +from fastmcp.server.providers.local_provider import LocalProvider from fastmcp.server.providers.transforming import TransformingProvider if TYPE_CHECKING: @@ -37,6 +38,7 @@ async def get_tool(self, name: str) -> Tool | None: __all__ = [ "FastMCPProvider", + "LocalProvider", "OpenAPIProvider", "Provider", "ProxyProvider", diff --git a/src/fastmcp/server/providers/base.py b/src/fastmcp/server/providers/base.py index aa5d8665b6..d6f7186e2d 100644 --- a/src/fastmcp/server/providers/base.py +++ b/src/fastmcp/server/providers/base.py @@ -31,6 +31,7 @@ async def get_tool(self, name: str) -> Tool | None: from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Literal from fastmcp.prompts.prompt import Prompt from fastmcp.resources.resource import Resource @@ -70,6 +71,26 @@ class Provider: This allows other providers to still contribute their components. """ + def _notify( + self, notification_type: Literal["tools", "resources", "prompts"] + ) -> None: + """Send a list changed notification if we're in a request context. + + This is a no-op if called outside a request context (e.g., during setup). + """ + try: + from fastmcp.server.dependencies import get_context + + context = get_context() + if notification_type == "tools": + context._queue_tool_list_changed() + elif notification_type == "resources": + context._queue_resource_list_changed() + elif notification_type == "prompts": + context._queue_prompt_list_changed() + except RuntimeError: + pass # No context available + def with_transforms( self, *, diff --git a/src/fastmcp/server/providers/fastmcp_provider.py b/src/fastmcp/server/providers/fastmcp_provider.py index 1fa827a16a..0d69133d98 100644 --- a/src/fastmcp/server/providers/fastmcp_provider.py +++ b/src/fastmcp/server/providers/fastmcp_provider.py @@ -563,32 +563,15 @@ async def get_tasks(self) -> TaskComponents: functions get registered with Docket. TransformingProvider.get_tasks() handles namespace transformation of keys. - Accesses managers directly to avoid triggering middleware during startup. + Iterates through all providers in the wrapped server (including its + LocalProvider) to collect task-eligible components. """ - # Return child's actual components - their .fn gets registered with Docket - # TransformingProvider.get_tasks() transforms keys to include namespace - tools: list[Tool] = [ - t - for t in self.server._tool_manager._tools.values() - if t.task_config.supports_tasks() - ] - resources: list[Resource] = [ - r - for r in self.server._resource_manager._resources.values() - if r.task_config.supports_tasks() - ] - templates: list[ResourceTemplate] = [ - t - for t in self.server._resource_manager._templates.values() - if t.task_config.supports_tasks() - ] - prompts: list[Prompt] = [ - p - for p in self.server._prompt_manager._prompts.values() - if p.task_config.supports_tasks() - ] + tools: list[Tool] = [] + resources: list[Resource] = [] + templates: list[ResourceTemplate] = [] + prompts: list[Prompt] = [] - # Recursively get tasks from nested providers + # Get tasks from all providers in the wrapped server for provider in self.server._providers: nested = await provider.get_tasks() tools.extend(nested.tools) diff --git a/src/fastmcp/server/providers/local_provider.py b/src/fastmcp/server/providers/local_provider.py new file mode 100644 index 0000000000..4b25a8f6a7 --- /dev/null +++ b/src/fastmcp/server/providers/local_provider.py @@ -0,0 +1,816 @@ +"""LocalProvider for locally-defined MCP components. + +This module provides the `LocalProvider` class that manages tools, resources, +templates, and prompts registered via decorators or direct methods. + +LocalProvider can be used standalone and attached to multiple servers: + +```python +from fastmcp.server.providers import LocalProvider + +# Create a reusable provider with tools +provider = LocalProvider() + +@provider.tool +def greet(name: str) -> str: + return f"Hello, {name}!" + +# Attach to any server +from fastmcp import FastMCP +server1 = FastMCP("Server1", providers=[provider]) +server2 = FastMCP("Server2", providers=[provider]) +``` +""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable, Sequence +from functools import partial +from typing import TYPE_CHECKING, Any, Literal, overload + +import mcp.types +from mcp.types import Annotations, AnyFunction, ToolAnnotations + +from fastmcp.prompts.prompt import FunctionPrompt, Prompt +from fastmcp.resources.resource import Resource +from fastmcp.resources.template import ResourceTemplate +from fastmcp.server.providers.base import Provider, TaskComponents +from fastmcp.server.tasks.config import TaskConfig +from fastmcp.tools.tool import FunctionTool, Tool +from fastmcp.tools.tool_transform import ( + ToolTransformConfig, + apply_transformations_to_tools, +) +from fastmcp.utilities.logging import get_logger +from fastmcp.utilities.types import NotSet, NotSetT + +if TYPE_CHECKING: + from fastmcp.tools.tool import ToolResultSerializerType + +logger = get_logger(__name__) + +DuplicateBehavior = Literal["error", "warn", "replace", "ignore"] + + +class LocalProvider(Provider): + """Provider for locally-defined components. + + Supports decorator-based registration (`@provider.tool`, `@provider.resource`, + `@provider.prompt`) and direct object registration methods. + + When used standalone, LocalProvider uses default settings. When attached + to a FastMCP server via the server's decorators, server-level settings + like `_tool_serializer` and `_support_tasks_by_default` are injected. + + Example: + ```python + from fastmcp.server.providers import LocalProvider + + # Standalone usage + provider = LocalProvider() + + @provider.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + @provider.resource("data://config") + def get_config() -> str: + return '{"setting": "value"}' + + @provider.prompt + def analyze(topic: str) -> list: + return [{"role": "user", "content": f"Analyze: {topic}"}] + + # Attach to server(s) + from fastmcp import FastMCP + server = FastMCP("MyServer", providers=[provider]) + ``` + """ + + def __init__( + self, + on_duplicate: DuplicateBehavior = "error", + ) -> None: + """Initialize a LocalProvider with empty storage. + + Args: + on_duplicate: Behavior when adding a component that already exists: + - "error": Raise ValueError + - "warn": Log warning and replace + - "replace": Silently replace + - "ignore": Keep existing, return it + """ + super().__init__() + self._on_duplicate = on_duplicate + self._tools: dict[str, Tool] = {} + self._resources: dict[str, Resource] = {} + self._templates: dict[str, ResourceTemplate] = {} + self._prompts: dict[str, Prompt] = {} + self._tool_transformations: dict[str, ToolTransformConfig] = {} + + # ========================================================================= + # Storage methods + # ========================================================================= + + def add_tool(self, tool: Tool) -> Tool: + """Add a tool to this provider's storage. + + Args: + tool: The Tool instance to add. + + Returns: + The tool that was added (or existing tool if on_duplicate="ignore"). + """ + existing = self._tools.get(tool.key) + if existing: + if self._on_duplicate == "error": + raise ValueError(f"Tool already exists: {tool.key}") + elif self._on_duplicate == "warn": + logger.warning(f"Tool already exists: {tool.key}") + elif self._on_duplicate == "ignore": + return existing + # "replace" and "warn" fall through to add + + self._tools[tool.key] = tool + self._notify("tools") + return tool + + def remove_tool(self, key: str) -> None: + """Remove a tool from this provider's storage. + + Args: + key: The key of the tool to remove. + + Raises: + KeyError: If the tool is not found. + """ + if key not in self._tools: + raise KeyError(f"Tool {key!r} not found") + del self._tools[key] + self._notify("tools") + + def add_resource(self, resource: Resource) -> Resource: + """Add a resource to this provider's storage. + + Args: + resource: The Resource instance to add. + + Returns: + The resource that was added (or existing if on_duplicate="ignore"). + """ + existing = self._resources.get(resource.key) + if existing: + if self._on_duplicate == "error": + raise ValueError(f"Resource already exists: {resource.key}") + elif self._on_duplicate == "warn": + logger.warning(f"Resource already exists: {resource.key}") + elif self._on_duplicate == "ignore": + return existing + + self._resources[resource.key] = resource + self._notify("resources") + return resource + + def remove_resource(self, key: str) -> None: + """Remove a resource from this provider's storage. + + Args: + key: The key of the resource to remove. + + Raises: + KeyError: If the resource is not found. + """ + if key not in self._resources: + raise KeyError(f"Resource {key!r} not found") + del self._resources[key] + self._notify("resources") + + def add_template(self, template: ResourceTemplate) -> ResourceTemplate: + """Add a resource template to this provider's storage. + + Args: + template: The ResourceTemplate instance to add. + + Returns: + The template that was added (or existing if on_duplicate="ignore"). + """ + existing = self._templates.get(template.key) + if existing: + if self._on_duplicate == "error": + raise ValueError(f"Template already exists: {template.key}") + elif self._on_duplicate == "warn": + logger.warning(f"Template already exists: {template.key}") + elif self._on_duplicate == "ignore": + return existing + + self._templates[template.key] = template + self._notify("resources") + return template + + def remove_template(self, key: str) -> None: + """Remove a resource template from this provider's storage. + + Args: + key: The key of the template to remove. + + Raises: + KeyError: If the template is not found. + """ + if key not in self._templates: + raise KeyError(f"Template {key!r} not found") + del self._templates[key] + self._notify("resources") + + def add_prompt(self, prompt: Prompt) -> Prompt: + """Add a prompt to this provider's storage. + + Args: + prompt: The Prompt instance to add. + + Returns: + The prompt that was added (or existing if on_duplicate="ignore"). + """ + existing = self._prompts.get(prompt.key) + if existing: + if self._on_duplicate == "error": + raise ValueError(f"Prompt already exists: {prompt.key}") + elif self._on_duplicate == "warn": + logger.warning(f"Prompt already exists: {prompt.key}") + elif self._on_duplicate == "ignore": + return existing + + self._prompts[prompt.key] = prompt + self._notify("prompts") + return prompt + + def remove_prompt(self, key: str) -> None: + """Remove a prompt from this provider's storage. + + Args: + key: The key of the prompt to remove. + + Raises: + KeyError: If the prompt is not found. + """ + if key not in self._prompts: + raise KeyError(f"Prompt {key!r} not found") + del self._prompts[key] + self._notify("prompts") + + # ========================================================================= + # Tool transformation methods + # ========================================================================= + + def add_tool_transformation( + self, tool_name: str, transformation: ToolTransformConfig + ) -> None: + """Add a tool transformation. + + Args: + tool_name: The name of the tool to transform. + transformation: The transformation configuration. + """ + self._tool_transformations[tool_name] = transformation + + def get_tool_transformation(self, tool_name: str) -> ToolTransformConfig | None: + """Get a tool transformation. + + Args: + tool_name: The name of the tool. + + Returns: + The transformation config, or None if not found. + """ + return self._tool_transformations.get(tool_name) + + def remove_tool_transformation(self, tool_name: str) -> None: + """Remove a tool transformation. + + Args: + tool_name: The name of the tool. + """ + if tool_name in self._tool_transformations: + del self._tool_transformations[tool_name] + + # ========================================================================= + # Provider interface implementation + # ========================================================================= + + async def list_tools(self) -> Sequence[Tool]: + """Return all tools with transformations applied.""" + transformed = apply_transformations_to_tools( + tools=self._tools, + transformations=self._tool_transformations, + ) + return list(transformed.values()) + + async def get_tool(self, name: str) -> Tool | None: + """Get a tool by name, with transformations applied.""" + tools = await self.list_tools() + return next((t for t in tools if t.name == name), None) + + async def list_resources(self) -> Sequence[Resource]: + """Return all resources.""" + return list(self._resources.values()) + + async def get_resource(self, uri: str) -> Resource | None: + """Get a resource by URI.""" + return self._resources.get(uri) + + async def list_resource_templates(self) -> Sequence[ResourceTemplate]: + """Return all resource templates.""" + return list(self._templates.values()) + + async def get_resource_template(self, uri: str) -> ResourceTemplate | None: + """Get a resource template that matches the given URI.""" + for template in self._templates.values(): + if template.matches(uri) is not None: + return template + return None + + async def list_prompts(self) -> Sequence[Prompt]: + """Return all prompts.""" + return list(self._prompts.values()) + + async def get_prompt(self, name: str) -> Prompt | None: + """Get a prompt by name.""" + return self._prompts.get(name) + + # ========================================================================= + # Task registration + # ========================================================================= + + async def get_tasks(self) -> TaskComponents: + """Return components eligible for background task execution. + + Returns components that have task_config.mode != 'forbidden'. + This includes both FunctionTool/Resource/Prompt instances created via + decorators and custom Tool/Resource/Prompt subclasses. + """ + return TaskComponents( + tools=[t for t in self._tools.values() if t.task_config.supports_tasks()], + resources=[ + r for r in self._resources.values() if r.task_config.supports_tasks() + ], + templates=[ + t for t in self._templates.values() if t.task_config.supports_tasks() + ], + prompts=[ + p for p in self._prompts.values() if p.task_config.supports_tasks() + ], + ) + + # ========================================================================= + # Decorator methods + # ========================================================================= + + @overload + def tool( + self, + name_or_fn: AnyFunction, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + output_schema: dict[str, Any] | NotSetT | None = NotSet, + annotations: ToolAnnotations | dict[str, Any] | None = None, + exclude_args: list[str] | None = None, + meta: dict[str, Any] | None = None, + enabled: bool | None = None, + task: bool | TaskConfig | None = None, + serializer: ToolResultSerializerType | None = None, + ) -> FunctionTool: ... + + @overload + def tool( + self, + name_or_fn: str | None = None, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + output_schema: dict[str, Any] | NotSetT | None = NotSet, + annotations: ToolAnnotations | dict[str, Any] | None = None, + exclude_args: list[str] | None = None, + meta: dict[str, Any] | None = None, + enabled: bool | None = None, + task: bool | TaskConfig | None = None, + serializer: ToolResultSerializerType | None = None, + ) -> Callable[[AnyFunction], FunctionTool]: ... + + def tool( + self, + name_or_fn: str | AnyFunction | None = None, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + output_schema: dict[str, Any] | NotSetT | None = NotSet, + annotations: ToolAnnotations | dict[str, Any] | None = None, + exclude_args: list[str] | None = None, + meta: dict[str, Any] | None = None, + enabled: bool | None = None, + task: bool | TaskConfig | None = None, + serializer: ToolResultSerializerType | None = None, + ) -> ( + Callable[[AnyFunction], FunctionTool] + | FunctionTool + | partial[Callable[[AnyFunction], FunctionTool] | FunctionTool] + ): + """Decorator to register a tool. + + This decorator supports multiple calling patterns: + - @provider.tool (without parentheses) + - @provider.tool() (with empty parentheses) + - @provider.tool("custom_name") (with name as first argument) + - @provider.tool(name="custom_name") (with name as keyword argument) + - provider.tool(function, name="custom_name") (direct function call) + + Args: + name_or_fn: Either a function (when used as @tool), a string name, or None + name: Optional name for the tool (keyword-only, alternative to name_or_fn) + title: Optional title for the tool + description: Optional description of what the tool does + icons: Optional icons for the tool + tags: Optional set of tags for categorizing the tool + output_schema: Optional JSON schema for the tool's output + annotations: Optional annotations about the tool's behavior + exclude_args: Optional list of argument names to exclude from the tool schema + meta: Optional meta information about the tool + enabled: Optional boolean to enable or disable the tool + task: Optional task configuration for background execution + serializer: Optional serializer for the tool result + + Returns: + The registered FunctionTool or a decorator function. + + Example: + ```python + provider = LocalProvider() + + @provider.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + @provider.tool("custom_name") + def my_tool(x: int) -> str: + return str(x) + ``` + """ + if isinstance(annotations, dict): + annotations = ToolAnnotations(**annotations) + + if isinstance(name_or_fn, classmethod): + raise ValueError( + inspect.cleandoc( + """ + To decorate a classmethod, first define the method and then call + tool() directly on the method instead of using it as a + decorator. See https://gofastmcp.com/patterns/decorating-methods + for examples and more information. + """ + ) + ) + + # Determine the actual name and function based on the calling pattern + if inspect.isroutine(name_or_fn): + # Case 1: @tool (without parens) - function passed directly + # Case 2: direct call like tool(fn, name="something") + fn = name_or_fn + tool_name = name # Use keyword name if provided, otherwise None + + # Resolve task parameter - default to False for standalone usage + supports_task: bool | TaskConfig = task if task is not None else False + + # Register the tool immediately and return the tool object + tool_obj = Tool.from_function( + fn, + name=tool_name, + title=title, + description=description, + icons=icons, + tags=tags, + output_schema=output_schema, + annotations=annotations, + exclude_args=exclude_args, + meta=meta, + serializer=serializer, + enabled=enabled, + task=supports_task, + ) + self.add_tool(tool_obj) + return tool_obj + + elif isinstance(name_or_fn, str): + # Case 3: @tool("custom_name") - name passed as first argument + if name is not None: + raise TypeError( + "Cannot specify both a name as first argument and as keyword argument. " + f"Use either @tool('{name_or_fn}') or @tool(name='{name}'), not both." + ) + tool_name = name_or_fn + elif name_or_fn is None: + # Case 4: @tool() or @tool(name="something") - use keyword name + tool_name = name + else: + raise TypeError( + f"First argument to @tool must be a function, string, or None, got {type(name_or_fn)}" + ) + + # Return partial for cases where we need to wait for the function + return partial( + self.tool, + name=tool_name, + title=title, + description=description, + icons=icons, + tags=tags, + output_schema=output_schema, + annotations=annotations, + exclude_args=exclude_args, + meta=meta, + enabled=enabled, + task=task, + serializer=serializer, + ) + + def resource( + self, + uri: str, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + mime_type: str | None = None, + tags: set[str] | None = None, + enabled: bool | None = None, + annotations: Annotations | dict[str, Any] | None = None, + meta: dict[str, Any] | None = None, + task: bool | TaskConfig | None = None, + ) -> Callable[[AnyFunction], Resource | ResourceTemplate]: + """Decorator to register a function as a resource. + + If the URI contains parameters (e.g. "resource://{param}") or the function + has parameters, it will be registered as a template resource. + + Args: + uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") + name: Optional name for the resource + title: Optional title for the resource + description: Optional description of the resource + icons: Optional icons for the resource + mime_type: Optional MIME type for the resource + tags: Optional set of tags for categorizing the resource + enabled: Optional boolean to enable or disable the resource + annotations: Optional annotations about the resource's behavior + meta: Optional meta information about the resource + task: Optional task configuration for background execution + + Returns: + A decorator function. + + Example: + ```python + provider = LocalProvider() + + @provider.resource("data://config") + def get_config() -> str: + return '{"setting": "value"}' + + @provider.resource("data://{city}/weather") + def get_weather(city: str) -> str: + return f"Weather for {city}" + ``` + """ + if isinstance(annotations, dict): + annotations = Annotations(**annotations) + + # Check if user passed function directly instead of calling decorator + if inspect.isroutine(uri): + raise TypeError( + "The @resource decorator was used incorrectly. " + "Did you forget to call it? Use @resource('uri') instead of @resource" + ) + + def decorator(fn: AnyFunction) -> Resource | ResourceTemplate: + if isinstance(fn, classmethod): + raise ValueError( + inspect.cleandoc( + """ + To decorate a classmethod, first define the method and then call + resource() directly on the method instead of using it as a + decorator. See https://gofastmcp.com/patterns/decorating-methods + for examples and more information. + """ + ) + ) + + # Resolve task parameter - default to False for standalone usage + supports_task: bool | TaskConfig = task if task is not None else False + + # Check if this should be a template + has_uri_params = "{" in uri and "}" in uri + # Use wrapper to check for user-facing parameters + from fastmcp.server.dependencies import without_injected_parameters + + wrapper_fn = without_injected_parameters(fn) + has_func_params = bool(inspect.signature(wrapper_fn).parameters) + + if has_uri_params or has_func_params: + template = ResourceTemplate.from_function( + fn=fn, + uri_template=uri, + name=name, + title=title, + description=description, + icons=icons, + mime_type=mime_type, + tags=tags, + enabled=enabled, + annotations=annotations, + meta=meta, + task=supports_task, + ) + self.add_template(template) + return template + elif not has_uri_params and not has_func_params: + resource_obj = Resource.from_function( + fn=fn, + uri=uri, + name=name, + title=title, + description=description, + icons=icons, + mime_type=mime_type, + tags=tags, + enabled=enabled, + annotations=annotations, + meta=meta, + task=supports_task, + ) + self.add_resource(resource_obj) + return resource_obj + else: + raise ValueError( + "Invalid resource or template definition due to a " + "mismatch between URI parameters and function parameters." + ) + + return decorator + + @overload + def prompt( + self, + name_or_fn: AnyFunction, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + enabled: bool | None = None, + meta: dict[str, Any] | None = None, + task: bool | TaskConfig | None = None, + ) -> FunctionPrompt: ... + + @overload + def prompt( + self, + name_or_fn: str | None = None, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + enabled: bool | None = None, + meta: dict[str, Any] | None = None, + task: bool | TaskConfig | None = None, + ) -> Callable[[AnyFunction], FunctionPrompt]: ... + + def prompt( + self, + name_or_fn: str | AnyFunction | None = None, + *, + name: str | None = None, + title: str | None = None, + description: str | None = None, + icons: list[mcp.types.Icon] | None = None, + tags: set[str] | None = None, + enabled: bool | None = None, + meta: dict[str, Any] | None = None, + task: bool | TaskConfig | None = None, + ) -> ( + Callable[[AnyFunction], FunctionPrompt] + | FunctionPrompt + | partial[Callable[[AnyFunction], FunctionPrompt] | FunctionPrompt] + ): + """Decorator to register a prompt. + + This decorator supports multiple calling patterns: + - @provider.prompt (without parentheses) + - @provider.prompt() (with empty parentheses) + - @provider.prompt("custom_name") (with name as first argument) + - @provider.prompt(name="custom_name") (with name as keyword argument) + - provider.prompt(function, name="custom_name") (direct function call) + + Args: + name_or_fn: Either a function (when used as @prompt), a string name, or None + name: Optional name for the prompt (keyword-only, alternative to name_or_fn) + title: Optional title for the prompt + description: Optional description of what the prompt does + icons: Optional icons for the prompt + tags: Optional set of tags for categorizing the prompt + enabled: Optional boolean to enable or disable the prompt + meta: Optional meta information about the prompt + task: Optional task configuration for background execution + + Returns: + The registered FunctionPrompt or a decorator function. + + Example: + ```python + provider = LocalProvider() + + @provider.prompt + def analyze(topic: str) -> list: + return [{"role": "user", "content": f"Analyze: {topic}"}] + + @provider.prompt("custom_name") + def my_prompt(data: str) -> list: + return [{"role": "user", "content": data}] + ``` + """ + if isinstance(name_or_fn, classmethod): + raise ValueError( + inspect.cleandoc( + """ + To decorate a classmethod, first define the method and then call + prompt() directly on the method instead of using it as a + decorator. See https://gofastmcp.com/patterns/decorating-methods + for examples and more information. + """ + ) + ) + + # Determine the actual name and function based on the calling pattern + if inspect.isroutine(name_or_fn): + # Case 1: @prompt (without parens) - function passed directly + # Case 2: direct call like prompt(fn, name="something") + fn = name_or_fn + prompt_name = name # Use keyword name if provided, otherwise None + + # Resolve task parameter - default to False for standalone usage + supports_task: bool | TaskConfig = task if task is not None else False + + # Register the prompt immediately + prompt_obj = Prompt.from_function( + fn=fn, + name=prompt_name, + title=title, + description=description, + icons=icons, + tags=tags, + enabled=enabled, + meta=meta, + task=supports_task, + ) + self.add_prompt(prompt_obj) + return prompt_obj + + elif isinstance(name_or_fn, str): + # Case 3: @prompt("custom_name") - name passed as first argument + if name is not None: + raise TypeError( + "Cannot specify both a name as first argument and as keyword argument. " + f"Use either @prompt('{name_or_fn}') or @prompt(name='{name}'), not both." + ) + prompt_name = name_or_fn + elif name_or_fn is None: + # Case 4: @prompt() or @prompt(name="something") - use keyword name + prompt_name = name + else: + raise TypeError( + f"First argument to @prompt must be a function, string, or None, got {type(name_or_fn)}" + ) + + # Return partial for cases where we need to wait for the function + return partial( + self.prompt, + name=prompt_name, + title=title, + description=description, + icons=icons, + tags=tags, + enabled=enabled, + meta=meta, + task=task, + ) diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index d116310029..5ca4248100 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import inspect import re import secrets import warnings @@ -58,6 +57,7 @@ import fastmcp.server from fastmcp.exceptions import ( DisabledError, + FastMCPError, NotFoundError, PromptError, ResourceError, @@ -67,9 +67,7 @@ from fastmcp.mcp_config import MCPConfig from fastmcp.prompts import Prompt from fastmcp.prompts.prompt import FunctionPrompt, PromptResult -from fastmcp.prompts.prompt_manager import PromptManager from fastmcp.resources.resource import Resource, ResourceContent -from fastmcp.resources.resource_manager import ResourceManager from fastmcp.resources.template import ResourceTemplate from fastmcp.server.auth import AuthProvider from fastmcp.server.event_store import EventStore @@ -80,12 +78,12 @@ ) from fastmcp.server.low_level import LowLevelServer from fastmcp.server.middleware import Middleware, MiddlewareContext -from fastmcp.server.providers import Provider +from fastmcp.server.providers import LocalProvider, Provider from fastmcp.server.tasks.capabilities import get_task_capabilities from fastmcp.server.tasks.config import TaskConfig +from fastmcp.settings import DuplicateBehavior as DuplicateBehaviorSetting from fastmcp.settings import Settings from fastmcp.tools.tool import FunctionTool, Tool, ToolResult -from fastmcp.tools.tool_manager import ToolManager from fastmcp.tools.tool_transform import ToolTransformConfig from fastmcp.utilities.cli import log_server_banner from fastmcp.utilities.components import FastMCPComponent @@ -107,6 +105,43 @@ DuplicateBehavior = Literal["warn", "error", "replace", "ignore"] + + +def _resolve_on_duplicate( + on_duplicate: DuplicateBehavior | None, + on_duplicate_tools: DuplicateBehavior | None, + on_duplicate_resources: DuplicateBehavior | None, + on_duplicate_prompts: DuplicateBehavior | None, +) -> DuplicateBehavior: + """Resolve on_duplicate from deprecated per-type params. + + Takes the most strict value if multiple are provided. + Delete this function when removing deprecated params. + """ + strictness_order: list[DuplicateBehavior] = ["error", "warn", "replace", "ignore"] + deprecated_values: list[DuplicateBehavior] = [] + + deprecated_params: list[tuple[str, DuplicateBehavior | None]] = [ + ("on_duplicate_tools", on_duplicate_tools), + ("on_duplicate_resources", on_duplicate_resources), + ("on_duplicate_prompts", on_duplicate_prompts), + ] + for name, value in deprecated_params: + if value is not None: + if fastmcp.settings.deprecation_warnings: + warnings.warn( + f"{name} is deprecated, use on_duplicate instead", + DeprecationWarning, + stacklevel=4, + ) + deprecated_values.append(value) + + if on_duplicate is None and deprecated_values: + return min(deprecated_values, key=lambda x: strictness_order.index(x)) + + return on_duplicate or "warn" + + Transport = Literal["stdio", "http", "sse", "streamable-http"] # Compiled URI parsing regex to split a URI into protocol and path components @@ -174,16 +209,15 @@ def __init__( include_tags: Collection[str] | None = None, exclude_tags: Collection[str] | None = None, include_fastmcp_meta: bool | None = None, - on_duplicate_tools: DuplicateBehavior | None = None, - on_duplicate_resources: DuplicateBehavior | None = None, - on_duplicate_prompts: DuplicateBehavior | None = None, + on_duplicate: DuplicateBehavior | None = None, strict_input_validation: bool | None = None, tasks: bool | None = None, # --- + # --- DEPRECATED parameters --- # --- - # --- The following arguments are DEPRECATED --- - # --- - # --- + on_duplicate_tools: DuplicateBehavior | None = None, + on_duplicate_resources: DuplicateBehavior | None = None, + on_duplicate_prompts: DuplicateBehavior | None = None, log_level: str | None = None, debug: bool | None = None, host: str | None = None, @@ -196,6 +230,14 @@ def __init__( sampling_handler: SamplingHandler | None = None, sampling_handler_behavior: Literal["always", "fallback"] | None = None, ): + # Resolve on_duplicate from deprecated params (delete when removing deprecation) + self._on_duplicate: DuplicateBehaviorSetting = _resolve_on_duplicate( + on_duplicate, + on_duplicate_tools, + on_duplicate_resources, + on_duplicate_prompts, + ) + # Resolve server default for background task support self._support_tasks_by_default: bool = tasks if tasks is not None else False @@ -203,20 +245,23 @@ def __init__( self._docket = None self._additional_http_routes: list[BaseRoute] = [] - self._providers: list[Provider] = list(providers or []) - self._tool_manager: ToolManager = ToolManager( - duplicate_behavior=on_duplicate_tools, - mask_error_details=mask_error_details, - transformations=tool_transformations, - ) - self._resource_manager: ResourceManager = ResourceManager( - duplicate_behavior=on_duplicate_resources, - mask_error_details=mask_error_details, - ) - self._prompt_manager: PromptManager = PromptManager( - duplicate_behavior=on_duplicate_prompts, - mask_error_details=mask_error_details, + + # Create LocalProvider for local components + self._local_provider: LocalProvider = LocalProvider( + on_duplicate=self._on_duplicate ) + + # Apply tool transformations to LocalProvider + if tool_transformations: + for tool_name, transformation in tool_transformations.items(): + self._local_provider.add_tool_transformation(tool_name, transformation) + + # LocalProvider is always first in the provider list + self._providers: list[Provider] = [ + self._local_provider, + *(providers or []), + ] + # Store mask_error_details for execution error handling self._mask_error_details: bool = ( mask_error_details @@ -413,21 +458,7 @@ async def _docket_lifespan(self) -> AsyncIterator[None]: # Store on server instance for cross-task access (FastMCPTransport) self._docket = docket - # Register local task-enabled components with Docket - # Each component checks task_config internally and no-ops if forbidden - for tool in self._tool_manager._tools.values(): - tool.register_with_docket(docket) - - for prompt in self._prompt_manager._prompts.values(): - prompt.register_with_docket(docket) - - for resource in self._resource_manager._resources.values(): - resource.register_with_docket(docket) - - for template in self._resource_manager._templates.values(): - template.register_with_docket(docket) - - # Register provider components + # Register task-enabled components from all providers (LocalProvider first) for provider in self._providers: try: tasks = await provider.get_tasks() @@ -676,15 +707,18 @@ def add_provider(self, provider: Provider) -> None: self._providers.append(provider) async def get_tools(self) -> dict[str, Tool]: - """Get all tools (unfiltered), including from providers, indexed by key.""" - all_tools = dict(await self._tool_manager.get_tools()) + """Get all tools (unfiltered), including from providers, indexed by key. - # Get tools from providers (including FastMCPProvider) + Iterates through all providers (LocalProvider first) and collects tools. + First provider wins for duplicate keys. + """ + all_tools: dict[str, Tool] = {} for provider in self._providers: try: provider_tools = await provider.list_tools() for tool in provider_tools: - all_tools[tool.key] = tool + if tool.key not in all_tools: + all_tools[tool.key] = tool except Exception as e: provider_name = getattr(provider, "server", provider).__class__.__name__ logger.warning( @@ -693,18 +727,14 @@ async def get_tools(self) -> dict[str, Tool]: if fastmcp.settings.mounted_components_raise_on_load_error: raise continue - return all_tools async def get_tool(self, key: str) -> Tool: - """Get a tool by key, checking local tools first then providers.""" - # Check local tools first - try: - return await self._tool_manager.get_tool(key) - except NotFoundError: - pass + """Get a tool by key. - # Try providers + Iterates through all providers (LocalProvider first) to find the tool. + First provider wins. + """ for provider in self._providers: try: tool = await provider.get_tool(key) @@ -722,19 +752,11 @@ async def _get_resource_or_template_or_none( Returns the original ResourceTemplate (not a Resource created from it) to preserve the registered function for task execution. - """ - # Check local concrete resources first - local_resources = await self._resource_manager.get_resources() - if uri in local_resources: - return local_resources[uri] - - # Check local templates - return the template itself, not a created resource - local_templates = await self._resource_manager.get_resource_templates() - for template in local_templates.values(): - if template.matches(uri): - return template - # Check providers + Iterates through all providers (LocalProvider first). + First provider wins. Checks concrete resources first, then templates. + """ + # First pass: check concrete resources from all providers for provider in self._providers: try: resource = await provider.get_resource(uri) @@ -743,7 +765,7 @@ async def _get_resource_or_template_or_none( except NotFoundError: continue - # Check provider templates + # Second pass: check templates from all providers for provider in self._providers: try: template = await provider.get_resource_template(uri) @@ -755,15 +777,18 @@ async def _get_resource_or_template_or_none( return None async def get_resources(self) -> dict[str, Resource]: - """Get all resources (unfiltered), including from providers, indexed by key.""" - all_resources = dict(await self._resource_manager.get_resources()) + """Get all resources (unfiltered), including from providers, indexed by key. - # Get resources from providers (including FastMCPProvider) + Iterates through all providers (LocalProvider first) and collects resources. + First provider wins for duplicate keys. + """ + all_resources: dict[str, Resource] = {} for provider in self._providers: try: provider_resources = await provider.list_resources() for resource in provider_resources: - all_resources[resource.key] = resource + if resource.key not in all_resources: + all_resources[resource.key] = resource except Exception as e: provider_name = getattr(provider, "server", provider).__class__.__name__ logger.warning( @@ -772,18 +797,14 @@ async def get_resources(self) -> dict[str, Resource]: if fastmcp.settings.mounted_components_raise_on_load_error: raise continue - return all_resources async def get_resource(self, key: str) -> Resource: - """Get a resource by key, checking local resources first then providers.""" - # Check local resources first - try: - return await self._resource_manager.get_resource(key) - except NotFoundError: - pass + """Get a resource by key. - # Try providers + Iterates through all providers (LocalProvider first) to find the resource. + First provider wins. + """ for provider in self._providers: try: resource = await provider.get_resource(key) @@ -795,15 +816,18 @@ async def get_resource(self, key: str) -> Resource: raise NotFoundError(f"Unknown resource: {key}") async def get_resource_templates(self) -> dict[str, ResourceTemplate]: - """Get all resource templates (unfiltered), including from providers, indexed by key.""" - all_templates = dict(await self._resource_manager.get_resource_templates()) + """Get all resource templates (unfiltered), including from providers, indexed by key. - # Get templates from providers (including FastMCPProvider) + Iterates through all providers (LocalProvider first) and collects templates. + First provider wins for duplicate keys. + """ + all_templates: dict[str, ResourceTemplate] = {} for provider in self._providers: try: provider_templates = await provider.list_resource_templates() for template in provider_templates: - all_templates[template.key] = template + if template.key not in all_templates: + all_templates[template.key] = template except Exception as e: provider_name = getattr(provider, "server", provider).__class__.__name__ logger.warning( @@ -812,20 +836,16 @@ async def get_resource_templates(self) -> dict[str, ResourceTemplate]: if fastmcp.settings.mounted_components_raise_on_load_error: raise continue - return all_templates async def get_resource_template(self, key: str) -> ResourceTemplate: - """Get a registered resource template by key.""" - # Check local templates first - local_templates = await self._resource_manager.get_resource_templates() - if key in local_templates: - return local_templates[key] + """Get a registered resource template by key. - # Try providers + Iterates through all providers (LocalProvider first) to find the template. + First provider wins. + """ for provider in self._providers: try: - # For templates, we use get_resource_template which matches by URI template = await provider.get_resource_template(key) if template is not None: return template @@ -835,15 +855,18 @@ async def get_resource_template(self, key: str) -> ResourceTemplate: raise NotFoundError(f"Unknown resource template: {key}") async def get_prompts(self) -> dict[str, Prompt]: - """Get all prompts (unfiltered), including from providers, indexed by key.""" - all_prompts = dict(await self._prompt_manager.get_prompts()) + """Get all prompts (unfiltered), including from providers, indexed by key. - # Get prompts from providers (including FastMCPProvider) + Iterates through all providers (LocalProvider first) and collects prompts. + First provider wins for duplicate keys. + """ + all_prompts: dict[str, Prompt] = {} for provider in self._providers: try: provider_prompts = await provider.list_prompts() for prompt in provider_prompts: - all_prompts[prompt.key] = prompt + if prompt.key not in all_prompts: + all_prompts[prompt.key] = prompt except Exception as e: provider_name = getattr(provider, "server", provider).__class__.__name__ logger.warning( @@ -852,18 +875,14 @@ async def get_prompts(self) -> dict[str, Prompt]: if fastmcp.settings.mounted_components_raise_on_load_error: raise continue - return all_prompts async def get_prompt(self, key: str) -> Prompt: - """Get a prompt by key, checking local prompts first then providers.""" - # Check local prompts first - try: - return await self._prompt_manager.get_prompt(key) - except NotFoundError: - pass + """Get a prompt by key. - # Try providers + Iterates through all providers (LocalProvider first) to find the prompt. + First provider wins. + """ for provider in self._providers: try: prompt = await provider.get_prompt(key) @@ -980,36 +999,25 @@ async def _list_tools( ) -> list[Tool]: """ List all available tools. - """ - # 1. Get local tools and filter them - local_tools = await self._tool_manager.get_tools() - local_tools_dict: dict[str, Tool] = { - tool.key: tool - for tool in local_tools.values() - if self._should_enable_component(tool) - } - # 2. Get tools from providers (later providers win for deduplication) - provider_tools_dict: dict[str, Tool] = {} + Iterates through all providers (LocalProvider first) and collects tools. + First provider wins for duplicate keys. + """ + all_tools: dict[str, Tool] = {} for provider in self._providers: try: provider_tools = await provider.list_tools() for tool in provider_tools: - if self._should_enable_component(tool): - # Later providers override earlier ones - provider_tools_dict[tool.key] = tool + if ( + self._should_enable_component(tool) + and tool.key not in all_tools + ): + all_tools[tool.key] = tool except Exception: logger.exception("Error listing tools from provider") if fastmcp.settings.mounted_components_raise_on_load_error: raise - - # Remove provider tools that conflict with local tools (local wins) - for key in local_tools_dict: - provider_tools_dict.pop(key, None) - - # Provider tools come first in the list (for visibility), - # but local tools take precedence for execution - return list(provider_tools_dict.values()) + list(local_tools_dict.values()) + return list(all_tools.values()) async def _list_resources_mcp(self) -> list[SDKResource]: """ @@ -1056,38 +1064,25 @@ async def _list_resources( ) -> list[Resource]: """ List all available resources. - """ - # 1. Filter local resources - local_resources = await self._resource_manager.get_resources() - local_resources_dict: dict[str, Resource] = { - resource.key: resource - for resource in local_resources.values() - if self._should_enable_component(resource) - } - # 2. Get resources from providers (later providers win for deduplication) - provider_resources_dict: dict[str, Resource] = {} + Iterates through all providers (LocalProvider first) and collects resources. + First provider wins for duplicate keys. + """ + all_resources: dict[str, Resource] = {} for provider in self._providers: try: provider_resources = await provider.list_resources() for resource in provider_resources: - if self._should_enable_component(resource): - # Later providers override earlier ones - provider_resources_dict[resource.key] = resource + if ( + self._should_enable_component(resource) + and resource.key not in all_resources + ): + all_resources[resource.key] = resource except Exception: logger.exception("Error listing resources from provider") if fastmcp.settings.mounted_components_raise_on_load_error: raise - - # Remove provider resources that conflict with local resources (local wins) - for key in local_resources_dict: - provider_resources_dict.pop(key, None) - - # Provider resources come first in the list (for visibility), - # but local resources take precedence for read operations - return list(provider_resources_dict.values()) + list( - local_resources_dict.values() - ) + return list(all_resources.values()) async def _list_resource_templates_mcp(self) -> list[SDKResourceTemplate]: """ @@ -1135,36 +1130,25 @@ async def _list_resource_templates( ) -> list[ResourceTemplate]: """ List all available resource templates. - """ - # 1. Filter local templates - local_templates = await self._resource_manager.get_resource_templates() - local_templates_dict: dict[str, ResourceTemplate] = { - template.key: template - for template in local_templates.values() - if self._should_enable_component(template) - } - # 2. Get resource templates from providers (later providers win for deduplication) - provider_templates_dict: dict[str, ResourceTemplate] = {} + Iterates through all providers (LocalProvider first) and collects templates. + First provider wins for duplicate keys. + """ + all_templates: dict[str, ResourceTemplate] = {} for provider in self._providers: try: provider_templates = await provider.list_resource_templates() for template in provider_templates: - if self._should_enable_component(template): - # Later providers override earlier ones - provider_templates_dict[template.key] = template + if ( + self._should_enable_component(template) + and template.key not in all_templates + ): + all_templates[template.key] = template except Exception: logger.exception("Error listing resource templates from provider") if fastmcp.settings.mounted_components_raise_on_load_error: raise - - # Remove provider templates that conflict with local templates (local wins) - for key in local_templates_dict: - provider_templates_dict.pop(key, None) - - return list(provider_templates_dict.values()) + list( - local_templates_dict.values() - ) + return list(all_templates.values()) async def _list_prompts_mcp(self) -> list[SDKPrompt]: """ @@ -1212,36 +1196,25 @@ async def _list_prompts( ) -> list[Prompt]: """ List all available prompts. - """ - # 1. Filter local prompts - local_prompts = await self._prompt_manager.get_prompts() - local_prompts_dict: dict[str, Prompt] = { - prompt.key: prompt - for prompt in local_prompts.values() - if self._should_enable_component(prompt) - } - # 2. Get prompts from providers (later providers win for deduplication) - provider_prompts_dict: dict[str, Prompt] = {} + Iterates through all providers (LocalProvider first) and collects prompts. + First provider wins for duplicate keys. + """ + all_prompts: dict[str, Prompt] = {} for provider in self._providers: try: provider_prompts = await provider.list_prompts() for prompt in provider_prompts: - if self._should_enable_component(prompt): - # Later providers override earlier ones - provider_prompts_dict[prompt.key] = prompt + if ( + self._should_enable_component(prompt) + and prompt.key not in all_prompts + ): + all_prompts[prompt.key] = prompt except Exception: logger.exception("Error listing prompts from provider") if fastmcp.settings.mounted_components_raise_on_load_error: raise - - # Remove provider prompts that conflict with local prompts (local wins) - for key in local_prompts_dict: - provider_prompts_dict.pop(key, None) - - # Provider prompts come first in the list (for visibility), - # but local prompts take precedence for render operations - return list(provider_prompts_dict.values()) + list(local_prompts_dict.values()) + return list(all_prompts.values()) async def _call_tool_mcp( self, key: str, arguments: dict[str, Any] @@ -1428,21 +1401,13 @@ async def _call_tool( context: MiddlewareContext[mcp.types.CallToolRequestParams], ) -> ToolResult | mcp.types.CreateTaskResult: """ - Call a tool + Call a tool. + + Iterates through all providers (LocalProvider first) to find the tool. + First provider wins. """ tool_name = context.message.name - # Try local tools first (static tools take precedence) - try: - tool = await self._tool_manager.get_tool(tool_name) - if self._should_enable_component(tool): - return await self._execute_tool( - tool, tool_name, context.message.arguments or {} - ) - except NotFoundError: - pass - - # Try component providers (first registered wins) for provider in self._providers: tool = await provider.get_tool(tool_name) if tool is not None and self._should_enable_component(tool): @@ -1462,13 +1427,13 @@ async def _execute_tool( """ try: return await tool._run(arguments) + except FastMCPError: + logger.exception(f"Error calling tool {tool_name!r}") + raise except (ValidationError, PydanticValidationError): # Validation errors are never masked - they indicate client input issues logger.exception(f"Error validating tool {tool_name!r}") raise - except ToolError: - logger.exception(f"Error calling tool {tool_name!r}") - raise except Exception as e: logger.exception(f"Error calling tool {tool_name!r}") if self._mask_error_details: @@ -1537,41 +1502,15 @@ async def _read_resource( """ Read a resource. + Iterates through all providers (LocalProvider first) to find the resource. + First provider wins. Checks concrete resources first, then templates. + Returns list[ResourceContent] for synchronous execution, or CreateTaskResult if the resource was submitted to Docket for background execution. """ - uri_str = str(context.message.uri) - # Try local concrete resources first (static resources take precedence) - # Note: Don't use get_resource() here because it creates resources from templates, - # which would bypass our template execution flow that handles task routing properly. - local_resources = await self._resource_manager.get_resources() - if uri_str in local_resources: - resource = local_resources[uri_str] - if self._should_enable_component(resource): - result = await self._execute_resource(resource, uri_str) - if isinstance(result, mcp.types.CreateTaskResult): - return result - # Use mime_type from ResourceContent if set, otherwise from resource - if result.mime_type is None: - result.mime_type = resource.mime_type - return [result] - - # Try local templates - templates = await self._resource_manager.get_resource_templates() - for template in templates.values(): - params = template.matches(uri_str) - if params is not None: - if self._should_enable_component(template): - # Templates need special task routing - call _execute_template - # which handles passing params to Docket - result = await self._execute_template(template, uri_str, params) - if isinstance(result, mcp.types.CreateTaskResult): - return result - return [result] - - # Try component providers (first registered wins) - concrete resources + # First pass: try concrete resources from all providers for provider in self._providers: resource = await provider.get_resource(uri_str) if resource is not None and self._should_enable_component(resource): @@ -1582,14 +1521,12 @@ async def _read_resource( result.mime_type = resource.mime_type return [result] - # Try component providers (first registered wins) - templates + # Second pass: try templates from all providers for provider in self._providers: template = await provider.get_resource_template(uri_str) if template is not None and self._should_enable_component(template): params = template.matches(uri_str) if params is not None: - # Templates need special task routing - call _execute_template - # which handles passing params to Docket result = await self._execute_template(template, uri_str, params) if isinstance(result, mcp.types.CreateTaskResult): return result @@ -1607,7 +1544,7 @@ async def _execute_resource( """ try: return await resource._read() - except (ResourceError, McpError): + except (FastMCPError, McpError): logger.exception(f"Error reading resource {uri_str!r}") raise except Exception as e: @@ -1629,7 +1566,7 @@ async def _execute_template( """ try: return await template._read(uri_str, params) - except (ResourceError, McpError): + except (FastMCPError, McpError): logger.exception(f"Error reading resource {uri_str!r}") raise except Exception as e: @@ -1709,22 +1646,14 @@ async def _get_prompt( """ Get a prompt. + Iterates through all providers (LocalProvider first) to find the prompt. + First provider wins. + Returns PromptResult for synchronous execution, or CreateTaskResult if the prompt was submitted to Docket for background execution. """ name = context.message.name - # Try local prompts first (static prompts take precedence) - try: - prompt = await self._prompt_manager.get_prompt(name) - if self._should_enable_component(prompt): - return await self._execute_prompt( - prompt, name, context.message.arguments - ) - except NotFoundError: - pass - - # Try component providers (first registered wins) for provider in self._providers: prompt = await provider.get_prompt(name) if prompt is not None and self._should_enable_component(prompt): @@ -1744,7 +1673,7 @@ async def _execute_prompt( """ try: return await prompt._render(arguments) - except (PromptError, McpError): + except (FastMCPError, McpError): logger.exception(f"Error rendering prompt {name!r}") raise except Exception as e: @@ -1765,18 +1694,7 @@ def add_tool(self, tool: Tool) -> Tool: Returns: The tool instance that was added to the server. """ - self._tool_manager.add_tool(tool) - - # Send notification if we're in a request context - try: - from fastmcp.server.dependencies import get_context - - context = get_context() - context._queue_tool_list_changed() # type: ignore[private-use] - except RuntimeError: - pass # No context available - - return tool + return self._local_provider.add_tool(tool) def remove_tool(self, name: str) -> None: """Remove a tool from the server. @@ -1787,26 +1705,20 @@ def remove_tool(self, name: str) -> None: Raises: NotFoundError: If the tool is not found """ - self._tool_manager.remove_tool(name) - - # Send notification if we're in a request context try: - from fastmcp.server.dependencies import get_context - - context = get_context() - context._queue_tool_list_changed() # type: ignore[private-use] - except RuntimeError: - pass # No context available + self._local_provider.remove_tool(name) + except KeyError: + raise NotFoundError(f"Tool {name!r} not found") from None def add_tool_transformation( self, tool_name: str, transformation: ToolTransformConfig ) -> None: """Add a tool transformation.""" - self._tool_manager.add_tool_transformation(tool_name, transformation) + self._local_provider.add_tool_transformation(tool_name, transformation) def remove_tool_transformation(self, tool_name: str) -> None: """Remove a tool transformation.""" - self._tool_manager.remove_tool_transformation(tool_name) + self._local_provider.remove_tool_transformation(tool_name) @overload def tool( @@ -1913,73 +1825,10 @@ def my_tool(x: int) -> str: server.tool(my_function, name="custom_name") ``` """ - if isinstance(annotations, dict): - annotations = ToolAnnotations(**annotations) - - if isinstance(name_or_fn, classmethod): - raise ValueError( - inspect.cleandoc( - """ - To decorate a classmethod, first define the method and then call - tool() directly on the method instead of using it as a - decorator. See https://gofastmcp.com/patterns/decorating-methods - for examples and more information. - """ - ) - ) - - # Determine the actual name and function based on the calling pattern - if inspect.isroutine(name_or_fn): - # Case 1: @tool (without parens) - function passed directly - # Case 2: direct call like tool(fn, name="something") - fn = name_or_fn - tool_name = name # Use keyword name if provided, otherwise None - - # Resolve task parameter - supports_task: bool | TaskConfig = ( - task if task is not None else self._support_tasks_by_default - ) - - # Register the tool immediately and return the tool object - # Note: Deprecation warning for exclude_args is handled in Tool.from_function - tool = Tool.from_function( - fn, - name=tool_name, - title=title, - description=description, - icons=icons, - tags=tags, - output_schema=output_schema, - annotations=annotations, - exclude_args=exclude_args, - meta=meta, - serializer=self._tool_serializer, - enabled=enabled, - task=supports_task, - ) - self.add_tool(tool) - return tool - - elif isinstance(name_or_fn, str): - # Case 3: @tool("custom_name") - name passed as first argument - if name is not None: - raise TypeError( - "Cannot specify both a name as first argument and as keyword argument. " - f"Use either @tool('{name_or_fn}') or @tool(name='{name}'), not both." - ) - tool_name = name_or_fn - elif name_or_fn is None: - # Case 4: @tool or @tool(name="something") - use keyword name - tool_name = name - else: - raise TypeError( - f"First argument to @tool must be a function, string, or None, got {type(name_or_fn)}" - ) - - # Return partial for cases where we need to wait for the function - return partial( - self.tool, - name=tool_name, + # Delegate to LocalProvider with server-level defaults + result = self._local_provider.tool( + name_or_fn, + name=name, title=title, description=description, icons=icons, @@ -1989,9 +1838,12 @@ def my_tool(x: int) -> str: exclude_args=exclude_args, meta=meta, enabled=enabled, - task=task, + task=task if task is not None else self._support_tasks_by_default, + serializer=self._tool_serializer, ) + return result + def add_resource(self, resource: Resource) -> Resource: """Add a resource to the server. @@ -2001,18 +1853,7 @@ def add_resource(self, resource: Resource) -> Resource: Returns: The resource instance that was added to the server. """ - self._resource_manager.add_resource(resource) - - # Send notification if we're in a request context - try: - from fastmcp.server.dependencies import get_context - - context = get_context() - context._queue_resource_list_changed() # type: ignore[private-use] - except RuntimeError: - pass # No context available - - return resource + return self._local_provider.add_resource(resource) def add_template(self, template: ResourceTemplate) -> ResourceTemplate: """Add a resource template to the server. @@ -2023,18 +1864,7 @@ def add_template(self, template: ResourceTemplate) -> ResourceTemplate: Returns: The template instance that was added to the server. """ - self._resource_manager.add_template(template) - - # Send notification if we're in a request context - try: - from fastmcp.server.dependencies import get_context - - context = get_context() - context._queue_resource_list_changed() # type: ignore[private-use] - except RuntimeError: - pass # No context available - - return template + return self._local_provider.add_template(template) def resource( self, @@ -2103,81 +1933,23 @@ async def get_weather(city: str) -> str: return f"Weather for {city}: {data}" ``` """ - if isinstance(annotations, dict): - annotations = Annotations(**annotations) - - # Check if user passed function directly instead of calling decorator - if inspect.isroutine(uri): - raise TypeError( - "The @resource decorator was used incorrectly. " - "Did you forget to call it? Use @resource('uri') instead of @resource" - ) + # Delegate to LocalProvider with server-level defaults + inner_decorator = self._local_provider.resource( + uri, + name=name, + title=title, + description=description, + icons=icons, + mime_type=mime_type, + tags=tags, + enabled=enabled, + annotations=annotations, + meta=meta, + task=task if task is not None else self._support_tasks_by_default, + ) def decorator(fn: AnyFunction) -> Resource | ResourceTemplate: - if isinstance(fn, classmethod): # type: ignore[reportUnnecessaryIsInstance] - raise ValueError( - inspect.cleandoc( - """ - To decorate a classmethod, first define the method and then call - resource() directly on the method instead of using it as a - decorator. See https://gofastmcp.com/patterns/decorating-methods - for examples and more information. - """ - ) - ) - - # Resolve task parameter - supports_task: bool | TaskConfig = ( - task if task is not None else self._support_tasks_by_default - ) - - # Check if this should be a template - has_uri_params = "{" in uri and "}" in uri - # Use wrapper to check for user-facing parameters - from fastmcp.server.dependencies import without_injected_parameters - - wrapper_fn = without_injected_parameters(fn) - has_func_params = bool(inspect.signature(wrapper_fn).parameters) - - if has_uri_params or has_func_params: - template = ResourceTemplate.from_function( - fn=fn, - uri_template=uri, - name=name, - title=title, - description=description, - icons=icons, - mime_type=mime_type, - tags=tags, - enabled=enabled, - annotations=annotations, - meta=meta, - task=supports_task, - ) - self.add_template(template) - return template - elif not has_uri_params and not has_func_params: - resource = Resource.from_function( - fn=fn, - uri=uri, - name=name, - title=title, - description=description, - icons=icons, - mime_type=mime_type, - tags=tags, - enabled=enabled, - annotations=annotations, - meta=meta, - task=supports_task, - ) - self.add_resource(resource) - return resource - else: - raise ValueError( - "Invalid resource or template definition due to a " - "mismatch between URI parameters and function parameters." - ) + return inner_decorator(fn) return decorator @@ -2190,18 +1962,7 @@ def add_prompt(self, prompt: Prompt) -> Prompt: Returns: The prompt instance that was added to the server. """ - self._prompt_manager.add_prompt(prompt) - - # Send notification if we're in a request context - try: - from fastmcp.server.dependencies import get_context - - context = get_context() - context._queue_prompt_list_changed() # type: ignore[private-use] - except RuntimeError: - pass # No context available - - return prompt + return self._local_provider.add_prompt(prompt) @overload def prompt( @@ -2319,74 +2080,17 @@ def another_prompt(data: str) -> list[Message]: server.prompt(my_function, name="custom_name") ``` """ - - if isinstance(name_or_fn, classmethod): - raise ValueError( - inspect.cleandoc( - """ - To decorate a classmethod, first define the method and then call - prompt() directly on the method instead of using it as a - decorator. See https://gofastmcp.com/patterns/decorating-methods - for examples and more information. - """ - ) - ) - - # Determine the actual name and function based on the calling pattern - if inspect.isroutine(name_or_fn): - # Case 1: @prompt (without parens) - function passed directly as decorator - # Case 2: direct call like prompt(fn, name="something") - fn = name_or_fn - prompt_name = name # Use keyword name if provided, otherwise None - - # Resolve task parameter - supports_task: bool | TaskConfig = ( - task if task is not None else self._support_tasks_by_default - ) - - # Register the prompt immediately - prompt = Prompt.from_function( - fn=fn, - name=prompt_name, - title=title, - description=description, - icons=icons, - tags=tags, - enabled=enabled, - meta=meta, - task=supports_task, - ) - self.add_prompt(prompt) - - return prompt - - elif isinstance(name_or_fn, str): - # Case 3: @prompt("custom_name") - name passed as first argument - if name is not None: - raise TypeError( - "Cannot specify both a name as first argument and as keyword argument. " - f"Use either @prompt('{name_or_fn}') or @prompt(name='{name}'), not both." - ) - prompt_name = name_or_fn - elif name_or_fn is None: - # Case 4: @prompt() or @prompt(name="something") - use keyword name - prompt_name = name - else: - raise TypeError( - f"First argument to @prompt must be a function, string, or None, got {type(name_or_fn)}" - ) - - # Return partial for cases where we need to wait for the function - return partial( - self.prompt, - name=prompt_name, + # Delegate to LocalProvider with server-level defaults + return self._local_provider.prompt( + name_or_fn, + name=name, title=title, description=description, icons=icons, tags=tags, enabled=enabled, meta=meta, - task=task, + task=task if task is not None else self._support_tasks_by_default, ) async def run_stdio_async( @@ -2713,14 +2417,14 @@ def add_resource_prefix(uri: str, prefix: str) -> str: for tool in (await server.get_tools()).values(): if prefix: tool = tool.model_copy(update={"name": f"{prefix}_{tool.name}"}) - self._tool_manager.add_tool(tool) + self.add_tool(tool) # Import resources and templates from the server for resource in (await server.get_resources()).values(): if prefix: new_uri = add_resource_prefix(str(resource.uri), prefix) resource = resource.model_copy(update={"uri": new_uri}) - self._resource_manager.add_resource(resource) + self.add_resource(resource) for template in (await server.get_resource_templates()).values(): if prefix: @@ -2728,13 +2432,13 @@ def add_resource_prefix(uri: str, prefix: str) -> str: template = template.model_copy( update={"uri_template": new_uri_template} ) - self._resource_manager.add_template(template) + self.add_template(template) # Import prompts from the server for prompt in (await server.get_prompts()).values(): if prefix: prompt = prompt.model_copy(update={"name": f"{prefix}_{prompt.name}"}) - self._prompt_manager.add_prompt(prompt) + self.add_prompt(prompt) if server._lifespan != default_lifespan: from warnings import warn diff --git a/src/fastmcp/tools/__init__.py b/src/fastmcp/tools/__init__.py index 6406020dcb..460396e828 100644 --- a/src/fastmcp/tools/__init__.py +++ b/src/fastmcp/tools/__init__.py @@ -1,5 +1,4 @@ -from .tool import Tool, FunctionTool -from .tool_manager import ToolManager +from .tool import FunctionTool, Tool from .tool_transform import forward, forward_raw -__all__ = ["FunctionTool", "Tool", "ToolManager", "forward", "forward_raw"] +__all__ = ["FunctionTool", "Tool", "forward", "forward_raw"] diff --git a/src/fastmcp/tools/tool_manager.py b/src/fastmcp/tools/tool_manager.py deleted file mode 100644 index 1c1e5bba43..0000000000 --- a/src/fastmcp/tools/tool_manager.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -import warnings -from collections.abc import Callable, Mapping -from typing import Any - -from mcp.types import ToolAnnotations -from pydantic import ValidationError - -from fastmcp import settings -from fastmcp.exceptions import FastMCPError, NotFoundError, ToolError -from fastmcp.settings import DuplicateBehavior -from fastmcp.tools.tool import Tool, ToolResult -from fastmcp.tools.tool_transform import ( - ToolTransformConfig, - apply_transformations_to_tools, -) -from fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class ToolManager: - """Manages FastMCP tools.""" - - def __init__( - self, - duplicate_behavior: DuplicateBehavior | None = None, - mask_error_details: bool | None = None, - transformations: Mapping[str, ToolTransformConfig] | None = None, - ): - self._tools: dict[str, Tool] = {} - self.mask_error_details: bool = ( - mask_error_details or settings.mask_error_details - ) - self.transformations: dict[str, ToolTransformConfig] = dict( - transformations or {} - ) - - # Default to "warn" if None is provided - if duplicate_behavior is None: - duplicate_behavior = "warn" - - if duplicate_behavior not in DuplicateBehavior.__args__: - raise ValueError( - f"Invalid duplicate_behavior: {duplicate_behavior}. " - f"Must be one of: {', '.join(DuplicateBehavior.__args__)}" - ) - - self.duplicate_behavior = duplicate_behavior - - async def _load_tools(self) -> dict[str, Tool]: - """Return this manager's local tools with transformations applied.""" - transformed_tools = apply_transformations_to_tools( - tools=self._tools, - transformations=self.transformations, - ) - return transformed_tools - - async def has_tool(self, key: str) -> bool: - """Check if a tool exists.""" - tools = await self.get_tools() - return key in tools - - async def get_tool(self, key: str) -> Tool: - """Get tool by key.""" - tools = await self.get_tools() - if key in tools: - return tools[key] - raise NotFoundError(f"Tool {key!r} not found") - - async def get_tools(self) -> dict[str, Tool]: - """ - Gets the complete, unfiltered inventory of local tools. - """ - return await self._load_tools() - - def add_tool_from_fn( - self, - fn: Callable[..., Any], - name: str | None = None, - description: str | None = None, - tags: set[str] | None = None, - annotations: ToolAnnotations | None = None, - serializer: Callable[[Any], str] | None = None, - exclude_args: list[str] | None = None, - ) -> Tool: - """Add a tool to the server.""" - # deprecated in 2.7.0 - if settings.deprecation_warnings: - warnings.warn( - "ToolManager.add_tool_from_fn() is deprecated. Use Tool.from_function() and call add_tool() instead.", - DeprecationWarning, - stacklevel=2, - ) - tool = Tool.from_function( - fn, - name=name, - description=description, - tags=tags, - annotations=annotations, - exclude_args=exclude_args, - serializer=serializer, - ) - return self.add_tool(tool) - - def add_tool(self, tool: Tool) -> Tool: - """Register a tool with the server.""" - existing = self._tools.get(tool.key) - if existing: - if self.duplicate_behavior == "warn": - logger.warning(f"Tool already exists: {tool.key}") - self._tools[tool.key] = tool - elif self.duplicate_behavior == "replace": - self._tools[tool.key] = tool - elif self.duplicate_behavior == "error": - raise ValueError(f"Tool already exists: {tool.key}") - elif self.duplicate_behavior == "ignore": - return existing - else: - self._tools[tool.key] = tool - return tool - - def add_tool_transformation( - self, tool_name: str, transformation: ToolTransformConfig - ) -> None: - """Add a tool transformation.""" - self.transformations[tool_name] = transformation - - def get_tool_transformation(self, tool_name: str) -> ToolTransformConfig | None: - """Get a tool transformation.""" - return self.transformations.get(tool_name) - - def remove_tool_transformation(self, tool_name: str) -> None: - """Remove a tool transformation.""" - if tool_name in self.transformations: - del self.transformations[tool_name] - - def remove_tool(self, key: str) -> None: - """Remove a tool from the server. - - Args: - key: The key of the tool to remove - - Raises: - NotFoundError: If the tool is not found - """ - if key in self._tools: - del self._tools[key] - else: - raise NotFoundError(f"Tool {key!r} not found") - - async def call_tool(self, key: str, arguments: dict[str, Any]) -> ToolResult: - """ - Internal API for servers: Finds and calls a tool. - - Note: Full error handling (logging, masking) is done at the FastMCP - server level. This method provides basic error wrapping for direct usage. - """ - tool = await self.get_tool(key) - try: - return await tool.run(arguments) - except FastMCPError: - raise - except ValidationError: - raise - except Exception as e: - if self.mask_error_details: - raise ToolError(f"Error calling tool {key!r}") from e - raise ToolError(f"Error calling tool {key!r}: {e}") from e diff --git a/tests/client/test_client.py b/tests/client/test_client.py index b2b7adecab..0990443fea 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1102,7 +1102,8 @@ def test_infer_composite_client(self): transport = infer_transport(config) assert isinstance(transport, MCPConfigTransport) assert isinstance(transport.transport, FastMCPTransport) - assert len(cast(FastMCP, transport.transport.server)._providers) == 2 + # 3 providers: LocalProvider (always first) + 2 mounted MCP servers + assert len(cast(FastMCP, transport.transport.server)._providers) == 3 def test_infer_fastmcp_server(self, fastmcp_server): """FastMCP server instances should infer to FastMCPTransport.""" diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 0ddc071bca..b4407e0c35 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -135,7 +135,7 @@ async def nested_server(): yield f"http://127.0.0.1:{port}/nest-outer/nest-inner/final/mcp" - # Cleanup: signal uvicorn to shutdown, then cancel the task + # Graceful shutdown - required for uvicorn 0.39+ due to context isolation uvicorn_server.should_exit = True with suppress(asyncio.CancelledError, asyncio.TimeoutError): await asyncio.wait_for(server_task, timeout=2.0) diff --git a/tests/contrib/test_component_manager.py b/tests/contrib/test_component_manager.py index e0fd2f04a8..29f64d5c75 100644 --- a/tests/contrib/test_component_manager.py +++ b/tests/contrib/test_component_manager.py @@ -79,7 +79,7 @@ def client(self, mcp): async def test_enable_tool_route(self, client, mcp): """Test enabling a tool via the HTTP route.""" # First disable the tool - tool = await mcp._tool_manager.get_tool("test_tool") + tool = await mcp.get_tool("test_tool") tool.enabled = False # Enable the tool via the HTTP route @@ -89,13 +89,13 @@ async def test_enable_tool_route(self, client, mcp): assert response.json() == {"message": "Enabled tool: test_tool"} # Verify the tool is enabled - tool = await mcp._tool_manager.get_tool("test_tool") + tool = await mcp.get_tool("test_tool") assert tool.enabled is True async def test_disable_tool_route(self, client, mcp): """Test disabling a tool via the HTTP route.""" # First ensure the tool is enabled - tool = await mcp._tool_manager.get_tool("test_tool") + tool = await mcp.get_tool("test_tool") tool.enabled = True # Disable the tool via the HTTP route @@ -105,13 +105,13 @@ async def test_disable_tool_route(self, client, mcp): assert response.json() == {"message": "Disabled tool: test_tool"} # Verify the tool is disabled - tool = await mcp._tool_manager.get_tool("test_tool") + tool = await mcp.get_tool("test_tool") assert tool.enabled is False async def test_enable_resource_route(self, client, mcp): """Test enabling a resource via the HTTP route.""" # First disable the resource - resource = await mcp._resource_manager.get_resource("data://test_resource") + resource = await mcp.get_resource("data://test_resource") resource.enabled = False # Enable the resource via the HTTP route @@ -121,13 +121,13 @@ async def test_enable_resource_route(self, client, mcp): assert response.json() == {"message": "Enabled resource: data://test_resource"} # Verify the resource is enabled - resource = await mcp._resource_manager.get_resource("data://test_resource") + resource = await mcp.get_resource("data://test_resource") assert resource.enabled is True async def test_disable_resource_route(self, client, mcp): """Test disabling a resource via the HTTP route.""" # First ensure the resource is enabled - resource = await mcp._resource_manager.get_resource("data://test_resource") + resource = await mcp.get_resource("data://test_resource") resource.enabled = True # Disable the resource via the HTTP route @@ -137,13 +137,13 @@ async def test_disable_resource_route(self, client, mcp): assert response.json() == {"message": "Disabled resource: data://test_resource"} # Verify the resource is disabled - resource = await mcp._resource_manager.get_resource("data://test_resource") + resource = await mcp.get_resource("data://test_resource") assert resource.enabled is False async def test_enable_template_route(self, client, mcp): """Test enabling a resource on a mounted server via the parent server's HTTP route.""" key = "data://test_resource/{id}" - resource = mcp._resource_manager._templates[key] + resource = await mcp.get_resource_template(key) resource.enabled = False response = client.post("/resources/data://test_resource/{id}/enable") assert response.status_code == status.HTTP_200_OK @@ -155,7 +155,7 @@ async def test_enable_template_route(self, client, mcp): async def test_disable_template_route(self, client, mcp): """Test disabling a resource on a mounted server via the parent server's HTTP route.""" key = "data://test_resource/{id}" - resource = mcp._resource_manager._templates[key] + resource = await mcp.get_resource_template(key) resource.enabled = True response = client.post("/resources/data://test_resource/{id}/disable") assert response.status_code == status.HTTP_200_OK @@ -167,7 +167,7 @@ async def test_disable_template_route(self, client, mcp): async def test_enable_prompt_route(self, client, mcp): """Test enabling a prompt via the HTTP route.""" # First disable the prompt - prompt = await mcp._prompt_manager.get_prompt("test_prompt") + prompt = await mcp.get_prompt("test_prompt") prompt.enabled = False # Enable the prompt via the HTTP route @@ -177,13 +177,13 @@ async def test_enable_prompt_route(self, client, mcp): assert response.json() == {"message": "Enabled prompt: test_prompt"} # Verify the prompt is enabled - prompt = await mcp._prompt_manager.get_prompt("test_prompt") + prompt = await mcp.get_prompt("test_prompt") assert prompt.enabled is True async def test_disable_prompt_route(self, client, mcp): """Test disabling a prompt via the HTTP route.""" # First ensure the prompt is enabled - prompt = await mcp._prompt_manager.get_prompt("test_prompt") + prompt = await mcp.get_prompt("test_prompt") prompt.enabled = True # Disable the prompt via the HTTP route @@ -193,13 +193,13 @@ async def test_disable_prompt_route(self, client, mcp): assert response.json() == {"message": "Disabled prompt: test_prompt"} # Verify the prompt is disabled - prompt = await mcp._prompt_manager.get_prompt("test_prompt") + prompt = await mcp.get_prompt("test_prompt") assert prompt.enabled is False async def test_enable_tool_route_on_mounted_server(self, client, mounted_mcp): """Test enabling a tool on a mounted server via the parent server's HTTP route.""" # Disable the tool on the sub-server - sub_tool = await mounted_mcp._tool_manager.get_tool("mounted_tool") + sub_tool = await mounted_mcp.get_tool("mounted_tool") sub_tool.enabled = False # Enable via parent response = client.post("/tools/sub_mounted_tool/enable") @@ -211,7 +211,7 @@ async def test_enable_tool_route_on_mounted_server(self, client, mounted_mcp): async def test_disable_tool_route_on_mounted_server(self, client, mounted_mcp): """Test disabling a tool on a mounted server via the parent server's HTTP route.""" # Enable the tool on the sub-server - sub_tool = await mounted_mcp._tool_manager.get_tool("mounted_tool") + sub_tool = await mounted_mcp.get_tool("mounted_tool") sub_tool.enabled = True # Disable via parent response = client.post("/tools/sub_mounted_tool/disable") @@ -222,40 +222,32 @@ async def test_disable_tool_route_on_mounted_server(self, client, mounted_mcp): async def test_enable_resource_route_on_mounted_server(self, client, mounted_mcp): """Test enabling a resource on a mounted server via the parent server's HTTP route.""" - resource = await mounted_mcp._resource_manager.get_resource( - "data://mounted_resource" - ) + resource = await mounted_mcp.get_resource("data://mounted_resource") resource.enabled = False response = client.post("/resources/data://sub/mounted_resource/enable") assert response.status_code == status.HTTP_200_OK assert response.json() == { "message": "Enabled resource: data://sub/mounted_resource" } - resource = await mounted_mcp._resource_manager.get_resource( - "data://mounted_resource" - ) + resource = await mounted_mcp.get_resource("data://mounted_resource") assert resource.enabled is True async def test_disable_resource_route_on_mounted_server(self, client, mounted_mcp): """Test disabling a resource on a mounted server via the parent server's HTTP route.""" - resource = await mounted_mcp._resource_manager.get_resource( - "data://mounted_resource" - ) + resource = await mounted_mcp.get_resource("data://mounted_resource") resource.enabled = True response = client.post("/resources/data://sub/mounted_resource/disable") assert response.status_code == status.HTTP_200_OK assert response.json() == { "message": "Disabled resource: data://sub/mounted_resource" } - resource = await mounted_mcp._resource_manager.get_resource( - "data://mounted_resource" - ) + resource = await mounted_mcp.get_resource("data://mounted_resource") assert resource.enabled is False async def test_enable_template_route_on_mounted_server(self, client, mounted_mcp): """Test enabling a resource on a mounted server via the parent server's HTTP route.""" key = "data://mounted_resource/{id}" - resource = mounted_mcp._resource_manager._templates[key] + resource = await mounted_mcp.get_resource_template(key) resource.enabled = False response = client.post("/resources/data://sub/mounted_resource/{id}/enable") assert response.status_code == status.HTTP_200_OK @@ -267,7 +259,7 @@ async def test_enable_template_route_on_mounted_server(self, client, mounted_mcp async def test_disable_template_route_on_mounted_server(self, client, mounted_mcp): """Test disabling a resource on a mounted server via the parent server's HTTP route.""" key = "data://mounted_resource/{id}" - resource = mounted_mcp._resource_manager._templates[key] + resource = await mounted_mcp.get_resource_template(key) resource.enabled = True response = client.post("/resources/data://sub/mounted_resource/{id}/disable") assert response.status_code == status.HTTP_200_OK @@ -278,22 +270,22 @@ async def test_disable_template_route_on_mounted_server(self, client, mounted_mc async def test_enable_prompt_route_on_mounted_server(self, client, mounted_mcp): """Test enabling a prompt on a mounted server via the parent server's HTTP route.""" - prompt = await mounted_mcp._prompt_manager.get_prompt("mounted_prompt") + prompt = await mounted_mcp.get_prompt("mounted_prompt") prompt.enabled = False response = client.post("/prompts/sub_mounted_prompt/enable") assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Enabled prompt: sub_mounted_prompt"} - prompt = await mounted_mcp._prompt_manager.get_prompt("mounted_prompt") + prompt = await mounted_mcp.get_prompt("mounted_prompt") assert prompt.enabled is True async def test_disable_prompt_route_on_mounted_server(self, client, mounted_mcp): """Test disabling a prompt on a mounted server via the parent server's HTTP route.""" - prompt = await mounted_mcp._prompt_manager.get_prompt("mounted_prompt") + prompt = await mounted_mcp.get_prompt("mounted_prompt") prompt.enabled = True response = client.post("/prompts/sub_mounted_prompt/disable") assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Disabled prompt: sub_mounted_prompt"} - prompt = await mounted_mcp._prompt_manager.get_prompt("mounted_prompt") + prompt = await mounted_mcp.get_prompt("mounted_prompt") assert prompt.enabled is False def test_enable_nonexistent_tool(self, client): @@ -383,7 +375,7 @@ def test_prompt() -> str: async def test_unauthorized_enable_tool(self): """Test that unauthenticated requests to enable a tool are rejected.""" - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post("/tools/test_tool/enable") @@ -392,7 +384,7 @@ async def test_unauthorized_enable_tool(self): async def test_authorized_enable_tool(self): """Test that authenticated requests to enable a tool are allowed.""" - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post( @@ -404,7 +396,7 @@ async def test_authorized_enable_tool(self): async def test_unauthorized_disable_tool(self): """Test that unauthenticated requests to disable a tool are rejected.""" - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = True response = self.client.post("/tools/test_tool/disable") @@ -413,7 +405,7 @@ async def test_unauthorized_disable_tool(self): async def test_authorized_disable_tool(self): """Test that authenticated requests to disable a tool are allowed.""" - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = True response = self.client.post( @@ -426,7 +418,7 @@ async def test_authorized_disable_tool(self): async def test_forbidden_enable_tool(self): """Test that requests with insufficient scopes are rejected.""" - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post( @@ -438,7 +430,7 @@ async def test_forbidden_enable_tool(self): async def test_authorized_enable_resource(self): """Test that authenticated requests to enable a resource are allowed.""" - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = False response = self.client.post( @@ -451,7 +443,7 @@ async def test_authorized_enable_resource(self): async def test_unauthorized_disable_resource(self): """Test that unauthenticated requests to disable a resource are rejected.""" - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = True response = self.client.post("/resources/data://test_resource/disable") @@ -460,7 +452,7 @@ async def test_unauthorized_disable_resource(self): async def test_forbidden_enable_resource(self): """Test that requests with insufficient scopes are rejected.""" - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = False response = self.client.post( @@ -472,7 +464,7 @@ async def test_forbidden_enable_resource(self): async def test_authorized_disable_resource(self): """Test that authenticated requests to disable a resource are allowed.""" - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = True response = self.client.post( @@ -485,7 +477,7 @@ async def test_authorized_disable_resource(self): async def test_unauthorized_enable_prompt(self): """Test that unauthenticated requests to enable a prompt are rejected.""" - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = False response = self.client.post("/prompts/test_prompt/enable") @@ -494,7 +486,7 @@ async def test_unauthorized_enable_prompt(self): async def test_authorized_enable_prompt(self): """Test that authenticated requests to enable a prompt are allowed.""" - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = False response = self.client.post( @@ -507,7 +499,7 @@ async def test_authorized_enable_prompt(self): async def test_unauthorized_disable_prompt(self): """Test that unauthenticated requests to disable a prompt are rejected.""" - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = True response = self.client.post("/prompts/test_prompt/disable") @@ -516,7 +508,7 @@ async def test_unauthorized_disable_prompt(self): async def test_forbidden_disable_prompt(self): """Test that requests with insufficient scopes are rejected.""" - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = True response = self.client.post( @@ -528,7 +520,7 @@ async def test_forbidden_disable_prompt(self): async def test_authorized_disable_prompt(self): """Test that authenticated requests to disable a prompt are allowed.""" - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = True response = self.client.post( @@ -567,36 +559,32 @@ def client_with_path(self, mcp_with_path): return TestClient(mcp_with_path.http_app()) async def test_enable_tool_route_with_path(self, client_with_path, mcp_with_path): - tool = await mcp_with_path._tool_manager.get_tool("test_tool") + tool = await mcp_with_path.get_tool("test_tool") tool.enabled = False response = client_with_path.post("/test/tools/test_tool/enable") assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Enabled tool: test_tool"} - tool = await mcp_with_path._tool_manager.get_tool("test_tool") + tool = await mcp_with_path.get_tool("test_tool") assert tool.enabled is True async def test_disable_resource_route_with_path( self, client_with_path, mcp_with_path ): - resource = await mcp_with_path._resource_manager.get_resource( - "data://test_resource" - ) + resource = await mcp_with_path.get_resource("data://test_resource") resource.enabled = True response = client_with_path.post("/test/resources/data://test_resource/disable") assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Disabled resource: data://test_resource"} - resource = await mcp_with_path._resource_manager.get_resource( - "data://test_resource" - ) + resource = await mcp_with_path.get_resource("data://test_resource") assert resource.enabled is False async def test_enable_prompt_route_with_path(self, client_with_path, mcp_with_path): - prompt = await mcp_with_path._prompt_manager.get_prompt("test_prompt") + prompt = await mcp_with_path.get_prompt("test_prompt") prompt.enabled = False response = client_with_path.post("/test/prompts/test_prompt/enable") assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Enabled prompt: test_prompt"} - prompt = await mcp_with_path._prompt_manager.get_prompt("test_prompt") + prompt = await mcp_with_path.get_prompt("test_prompt") assert prompt.enabled is True @@ -643,14 +631,14 @@ def test_prompt() -> str: self.client = TestClient(self.mcp.http_app()) async def test_unauthorized_enable_tool(self): - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post("/test/tools/test_tool/enable") assert response.status_code == 401 assert tool.enabled is False async def test_forbidden_enable_tool(self): - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post( "/test/tools/test_tool/enable", @@ -660,7 +648,7 @@ async def test_forbidden_enable_tool(self): assert tool.enabled is False async def test_authorized_enable_tool(self): - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") tool.enabled = False response = self.client.post( "/test/tools/test_tool/enable", @@ -668,18 +656,18 @@ async def test_authorized_enable_tool(self): ) assert response.status_code == 200 assert response.json() == {"message": "Enabled tool: test_tool"} - tool = await self.mcp._tool_manager.get_tool("test_tool") + tool = await self.mcp.get_tool("test_tool") assert tool.enabled is True async def test_unauthorized_disable_resource(self): - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = True response = self.client.post("/test/resources/data://test_resource/disable") assert response.status_code == 401 assert resource.enabled is True async def test_forbidden_disable_resource(self): - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = True response = self.client.post( "/test/resources/data://test_resource/disable", @@ -689,7 +677,7 @@ async def test_forbidden_disable_resource(self): assert resource.enabled is True async def test_authorized_disable_resource(self): - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") resource.enabled = True response = self.client.post( "/test/resources/data://test_resource/disable", @@ -697,18 +685,18 @@ async def test_authorized_disable_resource(self): ) assert response.status_code == 200 assert response.json() == {"message": "Disabled resource: data://test_resource"} - resource = await self.mcp._resource_manager.get_resource("data://test_resource") + resource = await self.mcp.get_resource("data://test_resource") assert resource.enabled is False async def test_unauthorized_enable_prompt(self): - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = False response = self.client.post("/test/prompts/test_prompt/enable") assert response.status_code == 401 assert prompt.enabled is False async def test_forbidden_enable_prompt(self): - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = False response = self.client.post( "/test/prompts/test_prompt/enable", @@ -718,7 +706,7 @@ async def test_forbidden_enable_prompt(self): assert prompt.enabled is False async def test_authorized_enable_prompt(self): - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") prompt.enabled = False response = self.client.post( "/test/prompts/test_prompt/enable", @@ -726,5 +714,5 @@ async def test_authorized_enable_prompt(self): ) assert response.status_code == 200 assert response.json() == {"message": "Enabled prompt: test_prompt"} - prompt = await self.mcp._prompt_manager.get_prompt("test_prompt") + prompt = await self.mcp.get_prompt("test_prompt") assert prompt.enabled is True diff --git a/tests/deprecated/test_exclude_args.py b/tests/deprecated/test_exclude_args.py index 759dd58117..141cfda361 100644 --- a/tests/deprecated/test_exclude_args.py +++ b/tests/deprecated/test_exclude_args.py @@ -7,8 +7,8 @@ from fastmcp.tools.tool import Tool -async def test_tool_exclude_args_in_tool_manager(): - """Test that tool args are excluded in the tool manager.""" +async def test_tool_exclude_args(): + """Test that tool args are excluded.""" mcp = FastMCP("Test Server") @mcp.tool(exclude_args=["state"]) @@ -19,7 +19,7 @@ def echo(message: str, state: dict[str, Any] | None = None) -> str: pass return message - tools_dict = await mcp._tool_manager.get_tools() + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert "state" not in tools[0].parameters["properties"] @@ -60,8 +60,8 @@ def create_item( ) mcp.add_tool(tool) - # Check internal tool objects directly - tools_dict = await mcp._tool_manager.get_tools() + # Check tool via public API + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert "state" not in tools[0].parameters["properties"] diff --git a/tests/deprecated/test_import_server.py b/tests/deprecated/test_import_server.py index 11208dbf04..a8f2ad70b6 100644 --- a/tests/deprecated/test_import_server.py +++ b/tests/deprecated/test_import_server.py @@ -24,11 +24,13 @@ def sub_tool() -> str: await main_app.import_server(sub_app, "sub") # Verify the tool was imported with the prefix - assert "sub_sub_tool" in main_app._tool_manager._tools - assert "sub_tool" in sub_app._tool_manager._tools + main_tools = await main_app.get_tools() + sub_tools = await sub_app.get_tools() + assert "sub_sub_tool" in main_tools + assert "sub_tool" in sub_tools # Verify the original tool still exists in the sub-app - tool = await main_app._tool_manager.get_tool("sub_sub_tool") + tool = await main_app.get_tool("sub_sub_tool") assert tool is not None # import_server creates copies with prefixed names (unlike mount which proxies) assert tool.name == "sub_sub_tool" @@ -57,8 +59,9 @@ def get_headlines() -> str: await main_app.import_server(news_app, "news") # Verify tools were imported with the correct prefixes - assert "weather_get_forecast" in main_app._tool_manager._tools - assert "news_get_headlines" in main_app._tool_manager._tools + tools = await main_app.get_tools() + assert "weather_get_forecast" in tools + assert "news_get_headlines" in tools async def test_import_combines_tools(): @@ -79,16 +82,18 @@ def second_tool() -> str: # Import first app await main_app.import_server(first_app, "api") - assert "api_first_tool" in main_app._tool_manager._tools + tools = await main_app.get_tools() + assert "api_first_tool" in tools # Import second app to same prefix await main_app.import_server(second_app, "api") # Verify second tool is there - assert "api_second_tool" in main_app._tool_manager._tools + tools = await main_app.get_tools() + assert "api_second_tool" in tools # Tools from both imports are combined - assert "api_first_tool" in main_app._tool_manager._tools + assert "api_first_tool" in tools async def test_import_with_resources(): @@ -106,7 +111,8 @@ async def get_users(): await main_app.import_server(data_app, "data") # Verify the resource was imported with the prefix - assert "data://data/users" in main_app._resource_manager._resources + resources = await main_app.get_resources() + assert "data://data/users" in resources async def test_import_with_resource_templates(): @@ -124,7 +130,8 @@ def get_user_profile(user_id: str) -> dict: await main_app.import_server(user_app, "api") # Verify the template was imported with the prefix - assert "users://api/{user_id}/profile" in main_app._resource_manager._templates + templates = await main_app.get_resource_templates() + assert "users://api/{user_id}/profile" in templates async def test_import_with_prompts(): @@ -142,7 +149,8 @@ def greeting(name: str) -> str: await main_app.import_server(assistant_app, "assistant") # Verify the prompt was imported with the prefix - assert "assistant_greeting" in main_app._prompt_manager._prompts + prompts = await main_app.get_prompts() + assert "assistant_greeting" in prompts async def test_import_multiple_resource_templates(): @@ -166,8 +174,9 @@ def get_news(category: str) -> str: await main_app.import_server(news_app, "content") # Verify templates were imported with correct prefixes - assert "weather://data/{city}" in main_app._resource_manager._templates - assert "news://content/{category}" in main_app._resource_manager._templates + templates = await main_app.get_resource_templates() + assert "weather://data/{city}" in templates + assert "news://content/{category}" in templates async def test_import_multiple_prompts(): @@ -191,8 +200,9 @@ def explain_sql(query: str) -> str: await main_app.import_server(sql_app, "sql") # Verify prompts were imported with correct prefixes - assert "python_review_python" in main_app._prompt_manager._prompts - assert "sql_explain_sql" in main_app._prompt_manager._prompts + prompts = await main_app.get_prompts() + assert "python_review_python" in prompts + assert "sql_explain_sql" in prompts async def test_tool_custom_name_preserved_when_imported(): @@ -207,7 +217,7 @@ def fetch_data(query: str) -> str: await main_app.import_server(api_app, "api") # Check that the tool is accessible by its prefixed name - tool = await main_app._tool_manager.get_tool("api_get_data") + tool = await main_app.get_tool("api_get_data") assert tool is not None # Check that the function name is preserved @@ -243,7 +253,7 @@ def calculate_value(input: int) -> int: await service_app.import_server(provider_app, "provider") # Tool is accessible in the service app with the first prefix - tool = await service_app._tool_manager.get_tool("provider_compute") + tool = await service_app.get_tool("provider_compute") assert tool is not None assert isinstance(tool, FunctionTool) assert get_fn_name(tool.fn) == "calculate_value" @@ -263,7 +273,7 @@ def calculate_value(input: int) -> int: await main_app.import_server(service_app, "service") # Tool is accessible in the main app with both prefixes - tool = await main_app._tool_manager.get_tool("service_provider_compute") + tool = await main_app.get_tool("service_provider_compute") assert tool is not None @@ -422,10 +432,14 @@ def sub_prompt() -> str: await main_app.import_server(sub_app) # Verify all component types are accessible with original names - assert "sub_tool" in main_app._tool_manager._tools - assert "data://config" in main_app._resource_manager._resources - assert "users://{user_id}/info" in main_app._resource_manager._templates - assert "sub_prompt" in main_app._prompt_manager._prompts + tools = await main_app.get_tools() + resources = await main_app.get_resources() + templates = await main_app.get_resource_templates() + prompts = await main_app.get_prompts() + assert "sub_tool" in tools + assert "data://config" in resources + assert "users://{user_id}/info" in templates + assert "sub_prompt" in prompts # Test actual functionality through Client async with Client(main_app) as client: diff --git a/tests/deprecated/test_settings.py b/tests/deprecated/test_settings.py index 1d95b94075..3766c878d5 100644 --- a/tests/deprecated/test_settings.py +++ b/tests/deprecated/test_settings.py @@ -173,9 +173,7 @@ def test_non_deprecated_kwargs_no_warnings(self): server = FastMCP( name="TestServer", instructions="Test instructions", - on_duplicate_tools="warn", - on_duplicate_resources="error", - on_duplicate_prompts="replace", + on_duplicate="warn", # New unified parameter mask_error_details=True, ) @@ -189,6 +187,29 @@ def test_non_deprecated_kwargs_no_warnings(self): assert server.name == "TestServer" assert server.instructions == "Test instructions" + def test_deprecated_duplicate_kwargs_raise_warnings(self): + """Test that deprecated on_duplicate_* kwargs raise warnings.""" + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") + FastMCP( + name="TestServer", + on_duplicate_tools="warn", + on_duplicate_resources="error", + on_duplicate_prompts="replace", + ) + + # Should have 3 deprecation warnings (one for each deprecated param) + deprecation_warnings = [ + w for w in recorded_warnings if issubclass(w.category, DeprecationWarning) + ] + assert len(deprecation_warnings) == 3 + + # Check warning messages + warning_messages = [str(w.message) for w in deprecation_warnings] + assert any("on_duplicate_tools" in msg for msg in warning_messages) + assert any("on_duplicate_resources" in msg for msg in warning_messages) + assert any("on_duplicate_prompts" in msg for msg in warning_messages) + def test_none_values_no_warnings(self): """Test that None values for deprecated kwargs don't raise warnings.""" with warnings.catch_warnings(record=True) as recorded_warnings: diff --git a/tests/prompts/test_prompt_manager.py b/tests/prompts/test_prompt_manager.py deleted file mode 100644 index d9c67270be..0000000000 --- a/tests/prompts/test_prompt_manager.py +++ /dev/null @@ -1,479 +0,0 @@ -import functools -from typing import Annotated - -import pytest - -from fastmcp import Context, FastMCP -from fastmcp.exceptions import NotFoundError, PromptError -from fastmcp.prompts import Prompt -from fastmcp.prompts.prompt import ( - FunctionPrompt, - PromptMessage, - PromptResult, - TextContent, -) -from fastmcp.prompts.prompt_manager import PromptManager -from fastmcp.utilities.tests import caplog_for_fastmcp -from tests.conftest import get_fn_name - - -class TestPromptManager: - async def test_add_prompt(self): - """Test adding a prompt to the manager.""" - - def fn() -> str: - return "Hello, world!" - - manager = PromptManager() - prompt = Prompt.from_function(fn) - added = manager.add_prompt(prompt) - assert added == prompt - assert await manager.get_prompt("fn") == prompt - - async def test_add_duplicate_prompt(self, caplog): - """Test adding the same prompt twice.""" - - def fn() -> str: - return "Hello, world!" - - manager = PromptManager(duplicate_behavior="warn") - prompt = Prompt.from_function(fn) - first = manager.add_prompt(prompt) - - with caplog_for_fastmcp(caplog): - second = manager.add_prompt(prompt) - - assert first == second - assert "Prompt already exists" in caplog.text - - async def test_disable_warn_on_duplicate_prompts(self, caplog): - """Test disabling warning on duplicate prompts.""" - - def fn() -> str: - return "Hello, world!" - - manager = PromptManager(duplicate_behavior="ignore") - prompt = Prompt.from_function(fn) - first = manager.add_prompt(prompt) - second = manager.add_prompt(prompt) - assert first == second - assert "Prompt already exists" not in caplog.text - - async def test_warn_on_duplicate_prompts(self, caplog): - """Test warning on duplicate prompts.""" - manager = PromptManager(duplicate_behavior="warn") - - def test_fn() -> str: - return "Test prompt" - - prompt = Prompt.from_function(test_fn, name="test_prompt") - - manager.add_prompt(prompt) - - with caplog_for_fastmcp(caplog): - manager.add_prompt(prompt) - - assert "Prompt already exists: test_prompt" in caplog.text - # Should have the prompt - assert await manager.get_prompt("test_prompt") is not None - - async def test_error_on_duplicate_prompts(self): - """Test error on duplicate prompts.""" - manager = PromptManager(duplicate_behavior="error") - - def test_fn() -> str: - return "Test prompt" - - prompt = Prompt.from_function(test_fn, name="test_prompt") - - manager.add_prompt(prompt) - - with pytest.raises(ValueError, match="Prompt already exists: test_prompt"): - manager.add_prompt(prompt) - - async def test_replace_duplicate_prompts(self): - """Test replacing duplicate prompts.""" - manager = PromptManager(duplicate_behavior="replace") - - def original_fn() -> str: - return "Original prompt" - - def replacement_fn() -> str: - return "Replacement prompt" - - prompt1 = Prompt.from_function(original_fn, name="test_prompt") - prompt2 = Prompt.from_function(replacement_fn, name="test_prompt") - - manager.add_prompt(prompt1) - manager.add_prompt(prompt2) - - # Should have replaced with the new prompt - prompt = await manager.get_prompt("test_prompt") - assert prompt is not None - assert isinstance(prompt, FunctionPrompt) - assert get_fn_name(prompt.fn) == "replacement_fn" - - async def test_ignore_duplicate_prompts(self): - """Test ignoring duplicate prompts.""" - manager = PromptManager(duplicate_behavior="ignore") - - def original_fn() -> str: - return "Original prompt" - - def replacement_fn() -> str: - return "Replacement prompt" - - prompt1 = Prompt.from_function(original_fn, name="test_prompt") - prompt2 = Prompt.from_function(replacement_fn, name="test_prompt") - - manager.add_prompt(prompt1) - result = manager.add_prompt(prompt2) - - # Should keep the original - prompt = await manager.get_prompt("test_prompt") - assert prompt is not None - assert isinstance(prompt, FunctionPrompt) - assert get_fn_name(prompt.fn) == "original_fn" - # Result should be the original prompt - assert isinstance(result, FunctionPrompt) - assert get_fn_name(result.fn) == "original_fn" - - async def test_get_prompts(self): - """Test retrieving all prompts.""" - - def fn1() -> str: - return "Hello, world!" - - def fn2() -> str: - return "Goodbye, world!" - - manager = PromptManager() - prompt1 = Prompt.from_function(fn1) - prompt2 = Prompt.from_function(fn2) - manager.add_prompt(prompt1) - manager.add_prompt(prompt2) - prompts = await manager.get_prompts() - assert len(prompts) == 2 - assert prompts["fn1"] == prompt1 - assert prompts["fn2"] == prompt2 - - -class TestRenderPrompt: - async def test_render_prompt(self): - """Test rendering a prompt.""" - - def fn() -> str: - """An example prompt.""" - return "Hello, world!" - - manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) - result = await manager.render_prompt("fn") - assert isinstance(result, PromptResult) - assert result.description == "An example prompt." - assert result.messages == [ - PromptMessage( - role="user", content=TextContent(type="text", text="Hello, world!") - ) - ] - - async def test_render_prompt_with_args(self): - """Test rendering a prompt with arguments.""" - - def fn(name: str) -> str: - """An example prompt.""" - return f"Hello, {name}!" - - manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) - result = await manager.render_prompt("fn", arguments={"name": "World"}) - assert isinstance(result, PromptResult) - assert result.description == "An example prompt." - assert result.messages == [ - PromptMessage( - role="user", content=TextContent(type="text", text="Hello, World!") - ) - ] - - async def test_render_prompt_callable_object(self): - """Test rendering a prompt with a callable object.""" - - class MyPrompt: - """A callable object that can be used as a prompt.""" - - def __call__(self, name: str) -> str: - """ignore this""" - return f"Hello, {name}!" - - manager = PromptManager() - prompt = Prompt.from_function(MyPrompt()) - manager.add_prompt(prompt) - result = await manager.render_prompt("MyPrompt", arguments={"name": "World"}) - assert isinstance(result, PromptResult) - assert result.description == "A callable object that can be used as a prompt." - assert result.messages == [ - PromptMessage( - role="user", content=TextContent(type="text", text="Hello, World!") - ) - ] - - async def test_render_prompt_callable_object_async(self): - """Test rendering a prompt with a callable object.""" - - class MyPrompt: - """A callable object that can be used as a prompt.""" - - async def __call__(self, name: str) -> str: - """ignore this""" - return f"Hello, {name}!" - - manager = PromptManager() - prompt = Prompt.from_function(MyPrompt()) - manager.add_prompt(prompt) - result = await manager.render_prompt("MyPrompt", arguments={"name": "World"}) - assert isinstance(result, PromptResult) - assert result.description == "A callable object that can be used as a prompt." - assert result.messages == [ - PromptMessage( - role="user", content=TextContent(type="text", text="Hello, World!") - ) - ] - - async def test_render_unknown_prompt(self): - """Test rendering a non-existent prompt.""" - manager = PromptManager() - with pytest.raises(NotFoundError, match="Unknown prompt: unknown"): - await manager.render_prompt("unknown") - - async def test_render_prompt_with_missing_args(self): - """Test rendering a prompt with missing required arguments.""" - - def fn(name: str) -> str: - return f"Hello, {name}!" - - manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) - with pytest.raises(PromptError, match="Missing required arguments"): - await manager.render_prompt("fn") - - async def test_prompt_with_varargs_not_allowed(self): - """Test that a prompt with *args is not allowed.""" - - def fn(*args: int) -> str: - return f"Hello, {args}!" - - manager = PromptManager() - with pytest.raises( - ValueError, match=r"Functions with \*args are not supported as prompts" - ): - manager.add_prompt(Prompt.from_function(fn)) - - async def test_prompt_with_varkwargs_not_allowed(self): - """Test that a prompt with **kwargs is not allowed.""" - - def fn(**kwargs: int) -> str: - return f"Hello, {kwargs}!" - - manager = PromptManager() - with pytest.raises( - ValueError, match=r"Functions with \*\*kwargs are not supported as prompts" - ): - manager.add_prompt(Prompt.from_function(fn)) - - -class TestPromptTags: - """Test functionality related to prompt tags.""" - - async def test_add_prompt_with_tags(self): - """Test adding a prompt with tags.""" - - def greeting() -> str: - return "Hello, world!" - - manager = PromptManager() - prompt = Prompt.from_function(greeting, tags={"greeting", "simple"}) - manager.add_prompt(prompt) - - prompt = await manager.get_prompt("greeting") - assert prompt is not None - assert prompt.tags == {"greeting", "simple"} - - async def test_add_prompt_with_empty_tags(self): - """Test adding a prompt with empty tags.""" - - def greeting() -> str: - return "Hello, world!" - - manager = PromptManager() - prompt = Prompt.from_function(greeting, tags=set()) - manager.add_prompt(prompt) - - prompt = await manager.get_prompt("greeting") - assert prompt is not None - assert prompt.tags == set() - - async def test_add_prompt_with_none_tags(self): - """Test adding a prompt with None tags.""" - - def greeting() -> str: - return "Hello, world!" - - manager = PromptManager() - prompt = Prompt.from_function(greeting, tags=None) - manager.add_prompt(prompt) - - prompt = await manager.get_prompt("greeting") - assert prompt is not None - assert prompt.tags == set() - - async def test_list_prompts_with_tags(self): - """Test listing prompts with specific tags.""" - - def greeting() -> str: - return "Hello, world!" - - def weather(location: str) -> str: - return f"Weather for {location}" - - def summary(text: str) -> str: - return f"Summary of: {text}" - - manager = PromptManager() - manager.add_prompt(Prompt.from_function(greeting, tags={"greeting", "simple"})) - manager.add_prompt(Prompt.from_function(weather, tags={"weather", "location"})) - manager.add_prompt( - Prompt.from_function(summary, tags={"summary", "nlp", "simple"}) - ) - - # Filter prompts by tags - prompts = await manager.get_prompts() - simple_prompts = [p for p in prompts.values() if "simple" in p.tags] - assert len(simple_prompts) == 2 - assert {p.name for p in simple_prompts} == {"greeting", "summary"} - - nlp_prompts = [p for p in prompts.values() if "nlp" in p.tags] - assert len(nlp_prompts) == 1 - assert nlp_prompts[0].name == "summary" - - -class TestContextHandling: - """Test context handling in prompts.""" - - def test_context_parameter_detection(self): - """Test that context parameters are properly detected in - Prompt.from_function().""" - - def prompt_with_context(x: int, ctx: Context) -> str: - return str(x) - - Prompt.from_function(prompt_with_context) - - def prompt_without_context(x: int) -> str: - return str(x) - - Prompt.from_function(prompt_without_context) - - def test_parameterized_context_parameter_detection(self): - """Test that parameterized context parameters are properly detected in - Prompt.from_function().""" - - def prompt_with_context(x: int, ctx: Context) -> str: - return str(x) - - Prompt.from_function(prompt_with_context) - - def test_parameterized_union_context_parameter_detection(self): - """Test that context parameters in a union are properly detected in - Prompt.from_function().""" - - def prompt_with_context(x: int, ctx: Context | None) -> str: - return str(x) - - Prompt.from_function(prompt_with_context) - - async def test_context_injection(self): - """Test that context is properly injected during prompt rendering.""" - - def prompt_with_context(x: int, ctx: Context) -> str: - assert isinstance(ctx, Context) - return str(x) - - prompt = Prompt.from_function(prompt_with_context) - - from fastmcp import FastMCP - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await prompt.render(arguments={"x": 42}) - - assert isinstance(result, PromptResult) - assert len(result.messages) == 1 - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "42" - - async def test_context_optional(self): - """Test that context is optional when rendering prompts.""" - - def prompt_with_context(x: int, ctx: Context | None = None) -> str: - return str(x) - - prompt = Prompt.from_function(prompt_with_context) - - # Even for optional context, we need to provide a context - from fastmcp import FastMCP - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await prompt.render( - arguments={"x": 42}, - ) - - assert isinstance(result, PromptResult) - assert len(result.messages) == 1 - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "42" - - async def test_annotated_context_parameter_detection(self): - """Test that annotated context parameters are properly detected in - Prompt.from_function().""" - - def prompt_with_context(x: int, ctx: Annotated[Context, "ctx"]) -> str: - return str(x) - - Prompt.from_function(prompt_with_context) - - async def test_context_with_functools_wraps_decorator(self): - """Regression test for #2524: decorated prompts with Context should work.""" - - def custom_decorator(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - return wrapper - - @custom_decorator - async def decorated_prompt(ctx: Context, topic: str) -> str: - assert isinstance(ctx, Context) - return f"Write about {topic}" - - prompt = Prompt.from_function(decorated_prompt) - - # Verify ctx is excluded from arguments - assert "ctx" not in [arg.name for arg in prompt.arguments or []] - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await prompt.render(arguments={"topic": "cats"}) - assert isinstance(result, PromptResult) - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "Write about cats" diff --git a/tests/resources/test_resource_manager.py b/tests/resources/test_resource_manager.py deleted file mode 100644 index 8b9e48d395..0000000000 --- a/tests/resources/test_resource_manager.py +++ /dev/null @@ -1,653 +0,0 @@ -from pathlib import Path -from tempfile import NamedTemporaryFile - -import pytest -from pydantic import AnyUrl, FileUrl - -from fastmcp.exceptions import NotFoundError, ResourceError -from fastmcp.resources import ( - FileResource, - ResourceManager, - ResourceTemplate, -) -from fastmcp.resources.resource import FunctionResource, ResourceContent -from fastmcp.utilities.tests import caplog_for_fastmcp - - -@pytest.fixture -def temp_file(): - """Create a temporary file for testing. - - File is automatically cleaned up after the test if it still exists. - """ - content = "test content" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write(content) - path = Path(f.name).resolve() - yield path - try: - path.unlink() - except FileNotFoundError: - pass # File was already deleted by the test - - -class TestResourceManager: - """Test ResourceManager functionality.""" - - async def test_add_resource(self, temp_file: Path): - """Test adding a resource.""" - manager = ResourceManager() - file_url = "file://test-resource" - resource = FileResource( - uri=FileUrl(file_url), - name="test", - path=temp_file, - ) - added = manager.add_resource(resource) - assert added == resource - # Get the actual key from the resource manager - resources = await manager.get_resources() - assert len(resources) == 1 - assert resource in resources.values() - - async def test_add_duplicate_resource(self, temp_file: Path): - """Test adding the same resource twice.""" - manager = ResourceManager() - file_url = "file://test-resource" - resource = FileResource( - uri=FileUrl(file_url), - name="test", - path=temp_file, - ) - first = manager.add_resource(resource) - second = manager.add_resource(resource) - assert first == second - # Check the resource is there - resources = await manager.get_resources() - assert len(resources) == 1 - assert resource in resources.values() - - async def test_warn_on_duplicate_resources(self, temp_file: Path, caplog): - """Test warning on duplicate resources.""" - manager = ResourceManager(duplicate_behavior="warn") - - file_url = "file://test-resource" - resource = FileResource( - uri=FileUrl(file_url), - name="test_resource", - path=temp_file, - ) - - manager.add_resource(resource) - - with caplog_for_fastmcp(caplog): - manager.add_resource(resource) - - assert "Resource already exists" in caplog.text - # Should have the resource - resources = await manager.get_resources() - assert len(resources) == 1 - assert resource in resources.values() - - async def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog): - """Test disabling warning on duplicate resources.""" - manager = ResourceManager(duplicate_behavior="ignore") - resource = FileResource( - uri=FileUrl(f"file://{temp_file.name}"), - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" not in caplog.text - - async def test_error_on_duplicate_resources(self, temp_file: Path): - """Test error on duplicate resources.""" - manager = ResourceManager(duplicate_behavior="error") - - resource = FileResource( - uri=FileUrl(f"file://{temp_file.name}"), - name="test_resource", - path=temp_file, - ) - - manager.add_resource(resource) - - with pytest.raises(ValueError, match="Resource already exists"): - manager.add_resource(resource) - - async def test_replace_duplicate_resources(self, temp_file: Path): - """Test replacing duplicate resources.""" - manager = ResourceManager(duplicate_behavior="replace") - - file_url = "file://test-resource" - resource1 = FileResource( - uri=FileUrl(file_url), - name="original", - path=temp_file, - ) - - resource2 = FileResource( - uri=FileUrl(file_url), - name="replacement", - path=temp_file, - ) - - manager.add_resource(resource1) - manager.add_resource(resource2) - - # Should have replaced with the new resource - resources = await manager.get_resources() - resource_list = list(resources.values()) - assert len(resource_list) == 1 - assert resource_list[0].name == "replacement" - - async def test_ignore_duplicate_resources(self, temp_file: Path): - """Test ignoring duplicate resources.""" - manager = ResourceManager(duplicate_behavior="ignore") - - file_url = "file://test-resource" - resource1 = FileResource( - uri=FileUrl(file_url), - name="original", - path=temp_file, - ) - - resource2 = FileResource( - uri=FileUrl(file_url), - name="replacement", - path=temp_file, - ) - - manager.add_resource(resource1) - result = manager.add_resource(resource2) - - # Should keep the original - resources = await manager.get_resources() - resource_list = list(resources.values()) - assert len(resource_list) == 1 - assert resource_list[0].name == "original" - # Result should be the original resource - assert result.name == "original" - - async def test_warn_on_duplicate_templates(self, caplog): - """Test warning on duplicate templates.""" - manager = ResourceManager(duplicate_behavior="warn") - - def template_fn(id: str) -> str: - return f"Template {id}" - - template = ResourceTemplate.from_function( - fn=template_fn, - uri_template="test://{id}", - name="test_template", - ) - - manager.add_template(template) - - with caplog_for_fastmcp(caplog): - manager.add_template(template) - - assert "Template already exists" in caplog.text - # Should have the template - templates = await manager.get_resource_templates() - assert templates == {"test://{id}": template} - - async def test_error_on_duplicate_templates(self): - """Test error on duplicate templates.""" - manager = ResourceManager(duplicate_behavior="error") - - def template_fn(id: str) -> str: - return f"Template {id}" - - template = ResourceTemplate.from_function( - fn=template_fn, - uri_template="test://{id}", - name="test_template", - ) - - manager.add_template(template) - - with pytest.raises(ValueError, match="Template already exists"): - manager.add_template(template) - - async def test_replace_duplicate_templates(self): - """Test replacing duplicate templates.""" - manager = ResourceManager(duplicate_behavior="replace") - - def original_fn(id: str) -> str: - return f"Original {id}" - - def replacement_fn(id: str) -> str: - return f"Replacement {id}" - - template1 = ResourceTemplate.from_function( - fn=original_fn, - uri_template="test://{id}", - name="original", - ) - - template2 = ResourceTemplate.from_function( - fn=replacement_fn, - uri_template="test://{id}", - name="replacement", - ) - - manager.add_template(template1) - manager.add_template(template2) - - # Should have replaced with the new template - templates_dict = await manager.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].name == "replacement" - - async def test_ignore_duplicate_templates(self): - """Test ignoring duplicate templates.""" - manager = ResourceManager(duplicate_behavior="ignore") - - def original_fn(id: str) -> str: - return f"Original {id}" - - def replacement_fn(id: str) -> str: - return f"Replacement {id}" - - template1 = ResourceTemplate.from_function( - fn=original_fn, - uri_template="test://{id}", - name="original", - ) - - template2 = ResourceTemplate.from_function( - fn=replacement_fn, - uri_template="test://{id}", - name="replacement", - ) - - manager.add_template(template1) - result = manager.add_template(template2) - - # Should keep the original - templates_dict = await manager.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].name == "original" - # Result should be the original template - assert result.name == "original" - - async def test_get_resource(self, temp_file: Path): - """Test getting a resource by URI.""" - manager = ResourceManager() - resource = FileResource( - uri=FileUrl(f"file://{temp_file.name}"), - name="test", - path=temp_file, - ) - manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri) - assert retrieved == resource - - async def test_get_resource_from_template(self): - """Test getting a resource through a template.""" - manager = ResourceManager() - - def greet(name: str) -> str: - return f"Hello, {name}!" - - template = ResourceTemplate.from_function( - fn=greet, - uri_template="greet://{name}", - name="greeter", - ) - manager._templates[template.uri_template] = template - - resource = await manager.get_resource(AnyUrl("greet://world")) - assert isinstance(resource, FunctionResource) - result = await resource.read() - assert isinstance(result, ResourceContent) - assert result.content == "Hello, world!" - - async def test_get_unknown_resource(self): - """Test getting a non-existent resource.""" - manager = ResourceManager() - with pytest.raises(NotFoundError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test")) - - async def test_get_resources(self, temp_file: Path): - """Test retrieving all resources.""" - manager = ResourceManager() - file_url1 = "file://test-resource1" - resource1 = FileResource( - uri=FileUrl(file_url1), - name="test1", - path=temp_file, - ) - file_url2 = "file://test-resource2" - resource2 = FileResource( - uri=FileUrl(file_url2), - name="test2", - path=temp_file, - ) - manager.add_resource(resource1) - manager.add_resource(resource2) - resources = await manager.get_resources() - assert len(resources) == 2 - values = list(resources.values()) - assert resource1 in values - assert resource2 in values - - -class TestResourceTags: - """Test functionality related to resource tags.""" - - async def test_add_resource_with_tags(self, temp_file: Path): - """Test adding a resource with tags.""" - manager = ResourceManager() - resource = FileResource( - uri=FileUrl("file://weather-data"), - name="weather_data", - path=temp_file, - tags={"weather", "data"}, - ) - manager.add_resource(resource) - - # Check that tags are preserved - resources_dict = await manager.get_resources() - resources = list(resources_dict.values()) - assert len(resources) == 1 - assert resources[0].tags == {"weather", "data"} - - async def test_add_function_resource_with_tags(self): - """Test adding a function resource with tags.""" - manager = ResourceManager() - - async def get_data(): - return "Sample data" - - resource = FunctionResource( - uri=AnyUrl("data://sample"), - name="sample_data", - description="Sample data resource", - mime_type="text/plain", - fn=get_data, - tags={"sample", "test", "data"}, - ) - - manager.add_resource(resource) - resources_dict = await manager.get_resources() - resources = list(resources_dict.values()) - assert len(resources) == 1 - assert resources[0].tags == {"sample", "test", "data"} - - async def test_add_template_with_tags(self): - """Test adding a resource template with tags.""" - manager = ResourceManager() - - def user_data(user_id: str) -> str: - return f"Data for user {user_id}" - - template = ResourceTemplate.from_function( - fn=user_data, - uri_template="users://{user_id}", - name="user_template", - description="Get user data by ID", - tags={"users", "template", "data"}, - ) - - manager.add_template(template) - templates_dict = await manager.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].tags == {"users", "template", "data"} - - async def test_filter_resources_by_tags(self, temp_file: Path): - """Test filtering resources by tags.""" - manager = ResourceManager() - - # Create multiple resources with different tags - resource1 = FileResource( - uri=FileUrl("file://weather-data"), - name="weather_data", - path=temp_file, - tags={"weather", "data"}, - ) - - async def get_user_data(): - return "User data" - - resource2 = FunctionResource( - uri=AnyUrl("data://users"), - name="user_data", - description="User data resource", - mime_type="text/plain", - fn=get_user_data, - tags={"users", "data"}, - ) - - async def get_system_data(): - return "System data" - - resource3 = FunctionResource( - uri=AnyUrl("data://system"), - name="system_data", - description="System data resource", - mime_type="text/plain", - fn=get_system_data, - tags={"system", "admin"}, - ) - - manager.add_resource(resource1) - manager.add_resource(resource2) - manager.add_resource(resource3) - - # Filter by tags - resources_dict = await manager.get_resources() - data_resources = [r for r in resources_dict.values() if "data" in r.tags] - assert len(data_resources) == 2 - assert {r.name for r in data_resources} == {"weather_data", "user_data"} - - admin_resources = [r for r in resources_dict.values() if "admin" in r.tags] - assert len(admin_resources) == 1 - assert admin_resources[0].name == "system_data" - - -class TestQueryOnlyTemplates: - """Test resource templates with only query parameters (no path params).""" - - async def test_template_with_only_query_params_no_query_string(self): - """Test that templates with only query params work without query string. - - Regression test for bug where empty parameter dict {} was treated as falsy, - causing templates with only query parameters to fail when no query string - was provided in the URI. - """ - manager = ResourceManager() - - def get_config(format: str = "json") -> str: - return f"Config in {format} format" - - template = ResourceTemplate.from_function( - fn=get_config, - uri_template="data://config{?format}", - name="config", - ) - manager.add_template(template) - - # Should work without query param (uses default) - resource = await manager.get_resource("data://config") - result = await resource.read() - assert isinstance(result, ResourceContent) - assert isinstance(result.content, str) - assert result.content == "Config in json format" - - # Should also work via read_resource - result = await manager.read_resource("data://config") - assert isinstance(result, ResourceContent) - assert result.content == "Config in json format" - - async def test_template_with_only_query_params_with_query_string(self): - """Test that templates with only query params work with query string.""" - manager = ResourceManager() - - def get_config(format: str = "json") -> str: - return f"Config in {format} format" - - template = ResourceTemplate.from_function( - fn=get_config, - uri_template="data://config{?format}", - name="config", - ) - manager.add_template(template) - - # Should work with query param (overrides default) - resource = await manager.get_resource("data://config?format=xml") - result = await resource.read() - assert isinstance(result, ResourceContent) - assert isinstance(result.content, str) - assert result.content == "Config in xml format" - - # Should also work via read_resource - result = await manager.read_resource("data://config?format=xml") - assert isinstance(result, ResourceContent) - assert result.content == "Config in xml format" - - async def test_template_with_only_multiple_query_params(self): - """Test template with only multiple query parameters.""" - manager = ResourceManager() - - def get_data(format: str = "json", limit: int = 10) -> str: - return f"Data in {format} (limit: {limit})" - - template = ResourceTemplate.from_function( - fn=get_data, - uri_template="data://items{?format,limit}", - name="items", - ) - manager.add_template(template) - - # No query params - use all defaults - result = await manager.read_resource("data://items") - assert isinstance(result, ResourceContent) - assert result.content == "Data in json (limit: 10)" - - # Partial query params - result = await manager.read_resource("data://items?format=xml") - assert isinstance(result, ResourceContent) - assert result.content == "Data in xml (limit: 10)" - - # All query params - result = await manager.read_resource("data://items?format=xml&limit=20") - assert isinstance(result, ResourceContent) - assert result.content == "Data in xml (limit: 20)" - - async def test_has_resource_with_query_only_template(self): - """Test that has_resource() works with query-only templates. - - Regression test for bug where empty parameter dict {} was treated as falsy, - causing has_resource() to return False for query-only templates when no - query string was provided. - """ - manager = ResourceManager() - - def get_config(format: str = "json") -> str: - return f"Config in {format} format" - - template = ResourceTemplate.from_function( - fn=get_config, - uri_template="data://config{?format}", - name="config", - ) - manager.add_template(template) - - # Should find resource without query param (uses default) - assert await manager.has_resource("data://config") - - # Should also find resource with query param - assert await manager.has_resource("data://config?format=xml") - - -class TestResourceErrorHandling: - """Test error handling in the ResourceManager.""" - - async def test_resource_error_passthrough(self): - """Test that ResourceErrors are passed through directly.""" - manager = ResourceManager() - - async def error_resource(): - """Resource that raises a ResourceError.""" - raise ResourceError("Specific resource error") - - resource = FunctionResource( - uri=AnyUrl("error://resource"), - name="error_resource", - fn=error_resource, - ) - manager.add_resource(resource) - - with pytest.raises(ResourceError, match="Specific resource error"): - await manager.read_resource("error://resource") - - async def test_template_resource_error_passthrough(self): - """Test that ResourceErrors from template-generated resources are passed through.""" - manager = ResourceManager() - - def error_template(param: str): - """Template that raises a ResourceError.""" - raise ResourceError(f"Template error with param {param}") - - template = ResourceTemplate.from_function( - fn=error_template, - uri_template="error://{param}", - name="error_template", - ) - manager.add_template(template) - - with pytest.raises(ResourceError) as excinfo: - await manager.read_resource("error://test") - - # The original error message should be included in the ValueError - assert "Template error with param test" in str(excinfo.value) - - async def test_exception_converted_to_resource_error_with_details(self): - """Test that other exceptions are converted to ResourceError with details by default.""" - manager = ResourceManager() - - async def buggy_resource(): - """Resource that raises a ValueError.""" - raise ValueError("Internal error details") - - resource = FunctionResource( - uri=AnyUrl("buggy://resource"), - name="buggy_resource", - fn=buggy_resource, - ) - manager.add_resource(resource) - - with pytest.raises(ResourceError) as excinfo: - await manager.read_resource("buggy://resource") - - # The error message should include the original exception details - assert "Error reading resource 'buggy://resource'" in str(excinfo.value) - assert "Internal error details" in str(excinfo.value) - - async def test_exception_converted_to_masked_resource_error(self): - """Test that other exceptions are masked when enabled.""" - manager = ResourceManager(mask_error_details=True) - - async def buggy_resource(): - """Resource that raises a ValueError.""" - raise ValueError("Internal error details") - - resource = FunctionResource( - uri=AnyUrl("buggy://resource"), - name="buggy_resource", - fn=buggy_resource, - ) - manager.add_resource(resource) - - with pytest.raises(ResourceError) as excinfo: - await manager.read_resource("buggy://resource") - - # The error message should not include the original exception details - assert "Error reading resource 'buggy://resource'" in str(excinfo.value) - assert "Internal error details" not in str(excinfo.value) diff --git a/tests/server/middleware/test_rate_limiting.py b/tests/server/middleware/test_rate_limiting.py index 1f17d70956..4fe1fb7ef2 100644 --- a/tests/server/middleware/test_rate_limiting.py +++ b/tests/server/middleware/test_rate_limiting.py @@ -396,8 +396,8 @@ def get_client_id(context): rate_limit_server.add_middleware( RateLimitingMiddleware( - max_requests_per_second=6.0, # Accounting for initialization and list_tools calls - burst_capacity=4, + max_requests_per_second=1.0, # Very slow refill to ensure rate limiting triggers + burst_capacity=4, # init + list_tools + call + list_tools = 4, so 2nd call fails get_client_id=get_client_id, ) ) @@ -407,10 +407,11 @@ def get_client_id(context): await client.call_tool("quick_action", {"message": "first"}) # Second should be rate limited for this specific client - with pytest.raises( - ToolError, match="Rate limit exceeded for client: test_client_123" - ): + with pytest.raises(ToolError) as exc_info: await client.call_tool("quick_action", {"message": "second"}) + assert "Rate limit exceeded for client: test_client_123" in str( + exc_info.value + ) async def test_global_rate_limiting(self, rate_limit_server): """Test global rate limiting across all clients.""" diff --git a/tests/server/providers/openapi/test_performance_comparison.py b/tests/server/providers/openapi/test_performance_comparison.py index 457c721047..f656101a5a 100644 --- a/tests/server/providers/openapi/test_performance_comparison.py +++ b/tests/server/providers/openapi/test_performance_comparison.py @@ -222,7 +222,7 @@ def test_server_initialization_performance(self, comprehensive_spec): f"Server should initialize in under 100ms, got {avg_time:.4f}s" ) - def test_functionality_after_optimization(self, comprehensive_spec): + async def test_functionality_after_optimization(self, comprehensive_spec): """Verify that performance optimization doesn't break functionality.""" client = httpx.AsyncClient(base_url="https://api.example.com") @@ -232,14 +232,8 @@ def test_functionality_after_optimization(self, comprehensive_spec): name="Test Server", ) - # Get tools from the provider - def get_provider_tools(server): - for provider in server._providers: - if hasattr(provider, "_tools"): - return provider._tools - return {} - - tools = get_provider_tools(server) + # Get tools from the server via public API + tools = await server.get_tools() # Should have 6 operations in the spec assert len(tools) == 6 @@ -258,12 +252,13 @@ def get_provider_tools(server): def test_memory_efficiency(self, comprehensive_spec): """Test that implementation doesn't significantly increase memory usage.""" - # Helper to get tools from provider - def get_provider_tools(server): + # Helper to count total tools across all providers + def count_provider_tools(server): + total = 0 for provider in server._providers: if hasattr(provider, "_tools"): - return provider._tools - return {} + total += len(provider._tools) + return total gc.collect() # Clean up before baseline baseline_refs = len(gc.get_objects()) @@ -280,7 +275,7 @@ def get_provider_tools(server): # Servers should all be functional assert len(servers) == 10 - assert all(len(get_provider_tools(s)) == 6 for s in servers) + assert all(count_provider_tools(s) == 6 for s in servers) # Memory usage shouldn't explode gc.collect() diff --git a/tests/server/providers/test_local_provider.py b/tests/server/providers/test_local_provider.py new file mode 100644 index 0000000000..3bbc90805a --- /dev/null +++ b/tests/server/providers/test_local_provider.py @@ -0,0 +1,632 @@ +"""Comprehensive tests for LocalProvider. + +Tests cover: +- Storage operations (add/remove tools, resources, templates, prompts) +- Provider interface (list/get operations) +- Decorator patterns (all calling styles) +- Tool transformations +- Standalone usage (provider attached to multiple servers) +- Task registration +""" + +from typing import Any + +import pytest + +from fastmcp import FastMCP +from fastmcp.client import Client +from fastmcp.prompts.prompt import Prompt +from fastmcp.server.providers.local_provider import LocalProvider +from fastmcp.server.tasks import TaskConfig +from fastmcp.tools.tool import Tool, ToolResult + + +class TestLocalProviderStorage: + """Tests for LocalProvider storage operations.""" + + def test_add_tool(self): + """Test adding a tool to LocalProvider.""" + provider = LocalProvider() + + tool = Tool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + provider.add_tool(tool) + + assert "test_tool" in provider._tools + assert provider._tools["test_tool"] is tool + + def test_add_multiple_tools(self): + """Test adding multiple tools.""" + provider = LocalProvider() + + tool1 = Tool( + name="tool1", + description="First tool", + parameters={"type": "object", "properties": {}}, + ) + tool2 = Tool( + name="tool2", + description="Second tool", + parameters={"type": "object", "properties": {}}, + ) + provider.add_tool(tool1) + provider.add_tool(tool2) + + assert "tool1" in provider._tools + assert "tool2" in provider._tools + + def test_remove_tool(self): + """Test removing a tool from LocalProvider.""" + provider = LocalProvider() + + tool = Tool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + provider.add_tool(tool) + provider.remove_tool("test_tool") + + assert "test_tool" not in provider._tools + + def test_remove_nonexistent_tool_raises(self): + """Test that removing a nonexistent tool raises KeyError.""" + provider = LocalProvider() + + with pytest.raises(KeyError): + provider.remove_tool("nonexistent") + + def test_add_resource(self): + """Test adding a resource to LocalProvider.""" + provider = LocalProvider() + + @provider.resource("resource://test") + def test_resource() -> str: + return "content" + + assert "resource://test" in provider._resources + + def test_remove_resource(self): + """Test removing a resource from LocalProvider.""" + provider = LocalProvider() + + @provider.resource("resource://test") + def test_resource() -> str: + return "content" + + provider.remove_resource("resource://test") + + assert "resource://test" not in provider._resources + + def test_add_template(self): + """Test adding a resource template to LocalProvider.""" + provider = LocalProvider() + + @provider.resource("resource://{id}") + def template_fn(id: str) -> str: + return f"Resource {id}" + + assert "resource://{id}" in provider._templates + + def test_remove_template(self): + """Test removing a resource template from LocalProvider.""" + provider = LocalProvider() + + @provider.resource("resource://{id}") + def template_fn(id: str) -> str: + return f"Resource {id}" + + provider.remove_template("resource://{id}") + + assert "resource://{id}" not in provider._templates + + def test_add_prompt(self): + """Test adding a prompt to LocalProvider.""" + provider = LocalProvider() + + prompt = Prompt( + name="test_prompt", + description="A test prompt", + ) + provider.add_prompt(prompt) + + assert "test_prompt" in provider._prompts + + def test_remove_prompt(self): + """Test removing a prompt from LocalProvider.""" + provider = LocalProvider() + + prompt = Prompt( + name="test_prompt", + description="A test prompt", + ) + provider.add_prompt(prompt) + provider.remove_prompt("test_prompt") + + assert "test_prompt" not in provider._prompts + + +class TestLocalProviderInterface: + """Tests for LocalProvider's Provider interface.""" + + async def test_list_tools_empty(self): + """Test listing tools when empty.""" + provider = LocalProvider() + tools = await provider.list_tools() + assert tools == [] + + async def test_list_tools(self): + """Test listing tools returns all stored tools.""" + provider = LocalProvider() + + tool1 = Tool(name="tool1", description="First", parameters={"type": "object"}) + tool2 = Tool(name="tool2", description="Second", parameters={"type": "object"}) + provider.add_tool(tool1) + provider.add_tool(tool2) + + tools = await provider.list_tools() + assert len(tools) == 2 + names = {t.name for t in tools} + assert names == {"tool1", "tool2"} + + async def test_get_tool_found(self): + """Test getting a tool that exists.""" + provider = LocalProvider() + + tool = Tool( + name="test_tool", + description="A test tool", + parameters={"type": "object"}, + ) + provider.add_tool(tool) + + result = await provider.get_tool("test_tool") + assert result is not None + assert result.name == "test_tool" + + async def test_get_tool_not_found(self): + """Test getting a tool that doesn't exist returns None.""" + provider = LocalProvider() + result = await provider.get_tool("nonexistent") + assert result is None + + async def test_list_resources(self): + """Test listing resources.""" + provider = LocalProvider() + + @provider.resource("resource://test") + def test_resource() -> str: + return "content" + + resources = await provider.list_resources() + assert len(resources) == 1 + assert str(resources[0].uri) == "resource://test" + + async def test_get_resource_found(self): + """Test getting a resource that exists.""" + provider = LocalProvider() + + @provider.resource("resource://test") + def test_resource() -> str: + return "content" + + result = await provider.get_resource("resource://test") + assert result is not None + assert str(result.uri) == "resource://test" + + async def test_get_resource_not_found(self): + """Test getting a resource that doesn't exist returns None.""" + provider = LocalProvider() + result = await provider.get_resource("resource://nonexistent") + assert result is None + + async def test_list_resource_templates(self): + """Test listing resource templates.""" + provider = LocalProvider() + + @provider.resource("resource://{id}") + def template_fn(id: str) -> str: + return f"Resource {id}" + + templates = await provider.list_resource_templates() + assert len(templates) == 1 + assert templates[0].uri_template == "resource://{id}" + + async def test_get_resource_template_match(self): + """Test getting a template that matches a URI.""" + provider = LocalProvider() + + @provider.resource("resource://{id}") + def template_fn(id: str) -> str: + return f"Resource {id}" + + result = await provider.get_resource_template("resource://123") + assert result is not None + assert result.uri_template == "resource://{id}" + + async def test_get_resource_template_no_match(self): + """Test getting a template with no match returns None.""" + provider = LocalProvider() + + @provider.resource("resource://{id}") + def template_fn(id: str) -> str: + return f"Resource {id}" + + result = await provider.get_resource_template("other://123") + assert result is None + + async def test_list_prompts(self): + """Test listing prompts.""" + provider = LocalProvider() + + prompt = Prompt( + name="test_prompt", + description="A test prompt", + ) + provider.add_prompt(prompt) + + prompts = await provider.list_prompts() + assert len(prompts) == 1 + assert prompts[0].name == "test_prompt" + + async def test_get_prompt_found(self): + """Test getting a prompt that exists.""" + provider = LocalProvider() + + prompt = Prompt( + name="test_prompt", + description="A test prompt", + ) + provider.add_prompt(prompt) + + result = await provider.get_prompt("test_prompt") + assert result is not None + assert result.name == "test_prompt" + + async def test_get_prompt_not_found(self): + """Test getting a prompt that doesn't exist returns None.""" + provider = LocalProvider() + result = await provider.get_prompt("nonexistent") + assert result is None + + +class TestLocalProviderDecorators: + """Tests for LocalProvider decorator methods.""" + + def test_tool_decorator_bare(self): + """Test @provider.tool without parentheses.""" + provider = LocalProvider() + + @provider.tool + def my_tool(x: int) -> int: + return x * 2 + + assert "my_tool" in provider._tools + assert provider._tools["my_tool"].name == "my_tool" + + def test_tool_decorator_with_parens(self): + """Test @provider.tool() with empty parentheses.""" + provider = LocalProvider() + + @provider.tool() + def my_tool(x: int) -> int: + return x * 2 + + assert "my_tool" in provider._tools + + def test_tool_decorator_with_name_kwarg(self): + """Test @provider.tool(name='custom').""" + provider = LocalProvider() + + @provider.tool(name="custom_name") + def my_tool(x: int) -> int: + return x * 2 + + assert "custom_name" in provider._tools + assert "my_tool" not in provider._tools + + def test_tool_decorator_with_description(self): + """Test @provider.tool(description='...').""" + provider = LocalProvider() + + @provider.tool(description="Custom description") + def my_tool(x: int) -> int: + return x * 2 + + assert provider._tools["my_tool"].description == "Custom description" + + def test_tool_direct_call(self): + """Test provider.tool(fn, name='...').""" + provider = LocalProvider() + + def my_tool(x: int) -> int: + return x * 2 + + provider.tool(my_tool, name="direct_tool") + + assert "direct_tool" in provider._tools + + async def test_tool_decorator_execution(self): + """Test that decorated tools execute correctly.""" + provider = LocalProvider() + + @provider.tool + def add(a: int, b: int) -> int: + return a + b + + server = FastMCP("Test", providers=[provider]) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + assert result.data == 5 + + def test_resource_decorator(self): + """Test @provider.resource decorator.""" + provider = LocalProvider() + + @provider.resource("resource://test") + def my_resource() -> str: + return "test content" + + assert "resource://test" in provider._resources + + def test_resource_decorator_with_name(self): + """Test @provider.resource with custom name.""" + provider = LocalProvider() + + @provider.resource("resource://test", name="custom_name") + def my_resource() -> str: + return "test content" + + assert provider._resources["resource://test"].name == "custom_name" + + async def test_resource_decorator_execution(self): + """Test that decorated resources execute correctly.""" + provider = LocalProvider() + + @provider.resource("resource://greeting") + def greeting() -> str: + return "Hello, World!" + + server = FastMCP("Test", providers=[provider]) + + async with Client(server) as client: + result = await client.read_resource("resource://greeting") + assert "Hello, World!" in str(result) + + def test_prompt_decorator_bare(self): + """Test @provider.prompt without parentheses.""" + provider = LocalProvider() + + @provider.prompt + def my_prompt() -> str: + return "A prompt" + + assert "my_prompt" in provider._prompts + + def test_prompt_decorator_with_parens(self): + """Test @provider.prompt() with empty parentheses.""" + provider = LocalProvider() + + @provider.prompt() + def my_prompt() -> str: + return "A prompt" + + assert "my_prompt" in provider._prompts + + def test_prompt_decorator_with_name(self): + """Test @provider.prompt(name='custom').""" + provider = LocalProvider() + + @provider.prompt(name="custom_prompt") + def my_prompt() -> str: + return "A prompt" + + assert "custom_prompt" in provider._prompts + assert "my_prompt" not in provider._prompts + + +class TestLocalProviderToolTransformations: + """Tests for tool transformations in LocalProvider.""" + + def test_add_tool_transformation(self): + """Test adding a tool transformation.""" + from fastmcp.tools.tool_transform import ToolTransformConfig + + provider = LocalProvider() + + @provider.tool + def my_tool(x: int) -> int: + return x + + config = ToolTransformConfig(name="renamed_tool") + provider.add_tool_transformation("my_tool", config) + + assert provider.get_tool_transformation("my_tool") is config + + async def test_list_tools_applies_transformations(self): + """Test that list_tools applies transformations.""" + from fastmcp.tools.tool_transform import ToolTransformConfig + + provider = LocalProvider() + + @provider.tool + def original_tool(x: int) -> int: + return x + + config = ToolTransformConfig(name="transformed_tool") + provider.add_tool_transformation("original_tool", config) + + tools = await provider.list_tools() + assert len(tools) == 1 + assert tools[0].name == "transformed_tool" + + async def test_get_tool_applies_transformation(self): + """Test that get_tool applies transformation.""" + from fastmcp.tools.tool_transform import ToolTransformConfig + + provider = LocalProvider() + + @provider.tool + def my_tool(x: int) -> int: + return x + + config = ToolTransformConfig(description="New description") + provider.add_tool_transformation("my_tool", config) + + tool = await provider.get_tool("my_tool") + assert tool is not None + assert tool.description == "New description" + + def test_remove_tool_transformation(self): + """Test removing a tool transformation.""" + from fastmcp.tools.tool_transform import ToolTransformConfig + + provider = LocalProvider() + + @provider.tool + def my_tool(x: int) -> int: + return x + + config = ToolTransformConfig(name="renamed") + provider.add_tool_transformation("my_tool", config) + provider.remove_tool_transformation("my_tool") + + assert provider.get_tool_transformation("my_tool") is None + + +class TestLocalProviderTaskRegistration: + """Tests for task registration in LocalProvider.""" + + async def test_get_tasks_returns_task_eligible_tools(self): + """Test that get_tasks returns tools with task support.""" + provider = LocalProvider() + + @provider.tool(task=True) + async def background_tool(x: int) -> int: + return x + + tasks = await provider.get_tasks() + assert len(tasks.tools) == 1 + assert tasks.tools[0].name == "background_tool" + + async def test_get_tasks_filters_forbidden_tools(self): + """Test that get_tasks excludes tools with forbidden task mode.""" + provider = LocalProvider() + + @provider.tool(task=False) + def sync_only_tool(x: int) -> int: + return x + + tasks = await provider.get_tasks() + assert len(tasks.tools) == 0 + + async def test_get_tasks_includes_custom_tool_subclasses(self): + """Test that custom Tool subclasses are included in get_tasks.""" + + class CustomTool(Tool): + task_config: TaskConfig = TaskConfig(mode="optional") + parameters: dict[str, Any] = {"type": "object", "properties": {}} + + async def run(self, arguments: dict[str, Any]) -> ToolResult: + return ToolResult(content="custom") + + provider = LocalProvider() + provider.add_tool(CustomTool(name="custom", description="Custom tool")) + + tasks = await provider.get_tasks() + assert len(tasks.tools) == 1 + assert tasks.tools[0].name == "custom" + + +class TestLocalProviderStandaloneUsage: + """Tests for standalone LocalProvider usage patterns.""" + + async def test_attach_provider_to_server(self): + """Test that LocalProvider can be attached to a server.""" + provider = LocalProvider() + + @provider.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + server = FastMCP("Test", providers=[provider]) + + async with Client(server) as client: + tools = await client.list_tools() + assert any(t.name == "greet" for t in tools) + + async def test_attach_provider_to_multiple_servers(self): + """Test that same provider can be attached to multiple servers.""" + provider = LocalProvider() + + @provider.tool + def shared_tool() -> str: + return "shared" + + server1 = FastMCP("Server1", providers=[provider]) + server2 = FastMCP("Server2", providers=[provider]) + + async with Client(server1) as client1: + tools1 = await client1.list_tools() + assert any(t.name == "shared_tool" for t in tools1) + + async with Client(server2) as client2: + tools2 = await client2.list_tools() + assert any(t.name == "shared_tool" for t in tools2) + + async def test_tools_visible_via_server_get_tools(self): + """Test that provider tools are visible via server.get_tools().""" + provider = LocalProvider() + + @provider.tool + def provider_tool() -> str: + return "from provider" + + server = FastMCP("Test", providers=[provider]) + + tools = await server.get_tools() + assert "provider_tool" in tools + + async def test_server_decorator_and_provider_tools_coexist(self): + """Test that server decorators and provider tools coexist.""" + provider = LocalProvider() + + @provider.tool + def provider_tool() -> str: + return "from provider" + + server = FastMCP("Test", providers=[provider]) + + @server.tool + def server_tool() -> str: + return "from server" + + tools = await server.get_tools() + assert "provider_tool" in tools + assert "server_tool" in tools + + async def test_local_provider_first_wins_duplicates(self): + """Test that LocalProvider tools take precedence over added providers.""" + provider = LocalProvider() + + @provider.tool + def duplicate_tool() -> str: + return "from added provider" + + server = FastMCP("Test", providers=[provider]) + + @server.tool + def duplicate_tool() -> str: # noqa: F811 + return "from server" + + # Server's LocalProvider is first, so its tool wins + tools = await server.get_tools() + assert "duplicate_tool" in tools + + async with Client(server) as client: + result = await client.call_tool("duplicate_tool", {}) + assert result.data == "from server" diff --git a/tests/server/providers/test_local_provider_prompts.py b/tests/server/providers/test_local_provider_prompts.py new file mode 100644 index 0000000000..d38ccfdedc --- /dev/null +++ b/tests/server/providers/test_local_provider_prompts.py @@ -0,0 +1,324 @@ +"""Tests for prompt behavior in LocalProvider. + +Tests cover: +- Prompt context injection +- Prompt decorator patterns +""" + +import pytest +from mcp.types import TextContent + +from fastmcp import Client, Context, FastMCP +from fastmcp.prompts.prompt import FunctionPrompt, Prompt, PromptResult + + +class TestPromptContext: + async def test_prompt_context(self): + mcp = FastMCP() + + @mcp.prompt + def prompt_fn(name: str, ctx: Context) -> str: + assert isinstance(ctx, Context) + return f"Hello, {name}! {ctx.request_id}" + + async with Client(mcp) as client: + result = await client.get_prompt("prompt_fn", {"name": "World"}) + assert len(result.messages) == 1 + message = result.messages[0] + assert message.role == "user" + + async def test_prompt_context_with_callable_object(self): + mcp = FastMCP() + + class MyPrompt: + def __call__(self, name: str, ctx: Context) -> str: + return f"Hello, {name}! {ctx.request_id}" + + mcp.add_prompt(Prompt.from_function(MyPrompt(), name="my_prompt")) + + async with Client(mcp) as client: + result = await client.get_prompt("my_prompt", {"name": "World"}) + assert len(result.messages) == 1 + message = result.messages[0] + assert message.role == "user" + assert isinstance(message.content, TextContent) + assert message.content.text == "Hello, World! 1" + + +class TestPromptDecorator: + async def test_prompt_decorator(self): + mcp = FastMCP() + + @mcp.prompt + def fn() -> str: + return "Hello, world!" + + prompts_dict = await mcp.get_prompts() + assert len(prompts_dict) == 1 + prompt = prompts_dict["fn"] + assert prompt.name == "fn" + content = await prompt.render() + if not isinstance(content, PromptResult): + content = PromptResult.from_value(content) + assert isinstance(content.messages[0].content, TextContent) + assert content.messages[0].content.text == "Hello, world!" + + async def test_prompt_decorator_without_parentheses(self): + mcp = FastMCP() + + @mcp.prompt + def fn() -> str: + return "Hello, world!" + + prompts = await mcp.get_prompts() + assert "fn" in prompts + + async with Client(mcp) as client: + result = await client.get_prompt("fn") + assert len(result.messages) == 1 + assert isinstance(result.messages[0].content, TextContent) + assert result.messages[0].content.text == "Hello, world!" + + async def test_prompt_decorator_with_name(self): + mcp = FastMCP() + + @mcp.prompt(name="custom_name") + def fn() -> str: + return "Hello, world!" + + prompts_dict = await mcp.get_prompts() + assert len(prompts_dict) == 1 + prompt = prompts_dict["custom_name"] + assert prompt.name == "custom_name" + content = await prompt.render() + if not isinstance(content, PromptResult): + content = PromptResult.from_value(content) + assert isinstance(content.messages[0].content, TextContent) + assert content.messages[0].content.text == "Hello, world!" + + async def test_prompt_decorator_with_description(self): + mcp = FastMCP() + + @mcp.prompt(description="A custom description") + def fn() -> str: + return "Hello, world!" + + prompts_dict = await mcp.get_prompts() + assert len(prompts_dict) == 1 + prompt = prompts_dict["fn"] + assert prompt.description == "A custom description" + content = await prompt.render() + if not isinstance(content, PromptResult): + content = PromptResult.from_value(content) + assert isinstance(content.messages[0].content, TextContent) + assert content.messages[0].content.text == "Hello, world!" + + async def test_prompt_decorator_with_parameters(self): + mcp = FastMCP() + + @mcp.prompt + def test_prompt(name: str, greeting: str = "Hello") -> str: + return f"{greeting}, {name}!" + + prompts_dict = await mcp.get_prompts() + assert len(prompts_dict) == 1 + prompt = prompts_dict["test_prompt"] + assert prompt.arguments is not None + assert len(prompt.arguments) == 2 + assert prompt.arguments[0].name == "name" + assert prompt.arguments[0].required is True + assert prompt.arguments[1].name == "greeting" + assert prompt.arguments[1].required is False + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt", {"name": "World"}) + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Hello, World!" + + result = await client.get_prompt( + "test_prompt", {"name": "World", "greeting": "Hi"} + ) + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Hi, World!" + + async def test_prompt_decorator_instance_method(self): + mcp = FastMCP() + + class MyClass: + def __init__(self, prefix: str): + self.prefix = prefix + + def test_prompt(self) -> str: + return f"{self.prefix} Hello, world!" + + obj = MyClass("My prefix:") + mcp.add_prompt(Prompt.from_function(obj.test_prompt, name="test_prompt")) + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt") + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "My prefix: Hello, world!" + + async def test_prompt_decorator_classmethod(self): + mcp = FastMCP() + + class MyClass: + prefix = "Class prefix:" + + @classmethod + def test_prompt(cls) -> str: + return f"{cls.prefix} Hello, world!" + + mcp.add_prompt(Prompt.from_function(MyClass.test_prompt, name="test_prompt")) + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt") + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Class prefix: Hello, world!" + + async def test_prompt_decorator_classmethod_error(self): + mcp = FastMCP() + + with pytest.raises(ValueError, match="To decorate a classmethod"): + + class MyClass: + @mcp.prompt + @classmethod + def test_prompt(cls) -> None: + pass + + async def test_prompt_decorator_staticmethod(self): + mcp = FastMCP() + + class MyClass: + @mcp.prompt + @staticmethod + def test_prompt() -> str: + return "Static Hello, world!" + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt") + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Static Hello, world!" + + async def test_prompt_decorator_async_function(self): + mcp = FastMCP() + + @mcp.prompt + async def test_prompt() -> str: + return "Async Hello, world!" + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt") + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Async Hello, world!" + + async def test_prompt_decorator_with_tags(self): + """Test that the prompt decorator properly sets tags.""" + mcp = FastMCP() + + @mcp.prompt(tags={"example", "test-tag"}) + def sample_prompt() -> str: + return "Hello, world!" + + prompts_dict = await mcp.get_prompts() + assert len(prompts_dict) == 1 + prompt = prompts_dict["sample_prompt"] + assert prompt.tags == {"example", "test-tag"} + + async def test_prompt_decorator_with_string_name(self): + """Test that @prompt(\"custom_name\") syntax works correctly.""" + mcp = FastMCP() + + @mcp.prompt("string_named_prompt") + def my_function() -> str: + """A function with a string name.""" + return "Hello from string named prompt!" + + prompts = await mcp.get_prompts() + assert "string_named_prompt" in prompts + assert "my_function" not in prompts + + async with Client(mcp) as client: + result = await client.get_prompt("string_named_prompt") + assert len(result.messages) == 1 + assert isinstance(result.messages[0].content, TextContent) + assert result.messages[0].content.text == "Hello from string named prompt!" + + async def test_prompt_direct_function_call(self): + """Test that prompts can be registered via direct function call.""" + mcp = FastMCP() + + def standalone_function() -> str: + """A standalone function to be registered.""" + return "Hello from direct call!" + + result_fn = mcp.prompt(standalone_function, name="direct_call_prompt") + + assert isinstance(result_fn, FunctionPrompt) + + prompts = await mcp.get_prompts() + assert prompts["direct_call_prompt"] is result_fn + + async with Client(mcp) as client: + result = await client.get_prompt("direct_call_prompt") + assert len(result.messages) == 1 + assert isinstance(result.messages[0].content, TextContent) + assert result.messages[0].content.text == "Hello from direct call!" + + async def test_prompt_decorator_conflicting_names_error(self): + """Test that providing both positional and keyword names raises an error.""" + mcp = FastMCP() + + with pytest.raises( + TypeError, + match="Cannot specify both a name as first argument and as keyword argument", + ): + + @mcp.prompt("positional_name", name="keyword_name") + def my_function() -> str: + return "Hello, world!" + + async def test_prompt_decorator_staticmethod_order(self): + """Test that both decorator orders work for static methods""" + mcp = FastMCP() + + class MyClass: + @mcp.prompt # type: ignore[misc] + @staticmethod + def test_prompt() -> str: + return "Static Hello, world!" + + async with Client(mcp) as client: + result = await client.get_prompt("test_prompt") + assert len(result.messages) == 1 + message = result.messages[0] + assert isinstance(message.content, TextContent) + assert message.content.text == "Static Hello, world!" + + async def test_prompt_decorator_with_meta(self): + """Test that meta parameter is passed through the prompt decorator.""" + mcp = FastMCP() + + meta_data = {"version": "3.0", "type": "prompt"} + + @mcp.prompt(meta=meta_data) + def test_prompt(message: str) -> str: + return f"Response: {message}" + + prompts_dict = await mcp.get_prompts() + prompt = prompts_dict["test_prompt"] + + assert prompt.meta == meta_data diff --git a/tests/server/providers/test_local_provider_resources.py b/tests/server/providers/test_local_provider_resources.py new file mode 100644 index 0000000000..06a0ea4d02 --- /dev/null +++ b/tests/server/providers/test_local_provider_resources.py @@ -0,0 +1,730 @@ +"""Tests for resource and template behavior in LocalProvider. + +Tests cover: +- Resource context injection +- Resource templates and URI parsing +- Resource template context injection +- Resource decorator patterns +- Template decorator patterns +""" + +import pytest +from mcp import McpError +from mcp.types import BlobResourceContents, TextResourceContents +from pydantic import AnyUrl + +from fastmcp import Client, Context, FastMCP +from fastmcp.resources import Resource, ResourceContent, ResourceTemplate + + +class TestResourceContext: + async def test_resource_with_context_annotation_gets_context(self): + mcp = FastMCP() + + @mcp.resource("resource://test") + def resource_with_context(ctx: Context) -> str: + assert isinstance(ctx, Context) + return ctx.request_id + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "1" + + +class TestResourceTemplates: + async def test_resource_with_params_not_in_uri(self): + """Test that a resource with function parameters raises an error if the URI + parameters don't match""" + mcp = FastMCP() + + with pytest.raises( + ValueError, + match="URI template must contain at least one parameter", + ): + + @mcp.resource("resource://data") + def get_data_fn(param: str) -> str: + return f"Data: {param}" + + async def test_resource_with_uri_params_without_args(self): + """Test that a resource with URI parameters is automatically a template""" + mcp = FastMCP() + + with pytest.raises( + ValueError, + match="URI parameters .* must be a subset of the function arguments", + ): + + @mcp.resource("resource://{param}") + def get_data() -> str: + return "Data" + + async def test_resource_with_untyped_params(self): + """Test that a resource with untyped parameters raises an error""" + mcp = FastMCP() + + @mcp.resource("resource://{param}") + def get_data(param) -> str: + return "Data" + + async def test_resource_matching_params(self): + """Test that a resource with matching URI and function parameters works""" + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data(name: str) -> str: + return f"Data for {name}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test/data")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Data for test" + + async def test_resource_mismatched_params(self): + """Test that mismatched parameters raise an error""" + mcp = FastMCP() + + with pytest.raises( + ValueError, + match="Required function arguments .* must be a subset of the URI path parameters", + ): + + @mcp.resource("resource://{name}/data") + def get_data(user: str) -> str: + return f"Data for {user}" + + async def test_resource_multiple_params(self): + """Test that multiple parameters work correctly""" + mcp = FastMCP() + + @mcp.resource("resource://{org}/{repo}/data") + def get_data(org: str, repo: str) -> str: + return f"Data for {org}/{repo}" + + async with Client(mcp) as client: + result = await client.read_resource( + AnyUrl("resource://cursor/fastmcp/data") + ) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Data for cursor/fastmcp" + + async def test_resource_multiple_mismatched_params(self): + """Test that mismatched parameters raise an error""" + mcp = FastMCP() + + with pytest.raises( + ValueError, + match="Required function arguments .* must be a subset of the URI path parameters", + ): + + @mcp.resource("resource://{org}/{repo}/data") + def get_data_mismatched(org: str, repo_2: str) -> str: + return f"Data for {org}" + + async def test_template_with_varkwargs(self): + """Test that a template can have **kwargs.""" + mcp = FastMCP() + + @mcp.resource("test://{x}/{y}/{z}") + def func(**kwargs: int) -> int: + return sum(kwargs.values()) + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("test://1/2/3")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "6" + + async def test_template_with_default_params(self): + """Test that a template can have default parameters.""" + mcp = FastMCP() + + @mcp.resource("math://add/{x}") + def add(x: int, y: int = 10) -> int: + return x + y + + templates_dict = await mcp.get_resource_templates() + templates = list(templates_dict.values()) + assert len(templates) == 1 + assert templates[0].uri_template == "math://add/{x}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("math://add/5")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "15" + + result2 = await client.read_resource(AnyUrl("math://add/7")) + assert isinstance(result2[0], TextResourceContents) + assert result2[0].text == "17" + + async def test_template_to_resource_conversion(self): + """Test that a template can be converted to a resource.""" + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data(name: str) -> str: + return f"Data for {name}" + + templates_dict = await mcp.get_resource_templates() + templates = list(templates_dict.values()) + assert len(templates) == 1 + assert templates[0].uri_template == "resource://{name}/data" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test/data")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Data for test" + + async def test_template_decorator_with_tags(self): + mcp = FastMCP() + + @mcp.resource("resource://{param}", tags={"template", "test-tag"}) + def template_resource(param: str) -> str: + return f"Template resource: {param}" + + templates_dict = await mcp.get_resource_templates() + template = templates_dict["resource://{param}"] + assert template.tags == {"template", "test-tag"} + + async def test_template_decorator_wildcard_param(self): + mcp = FastMCP() + + @mcp.resource("resource://{param*}") + def template_resource(param: str) -> str: + return f"Template resource: {param}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test/data")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Template resource: test/data" + + async def test_template_with_query_params(self): + """Test RFC 6570 query parameters in resource templates.""" + mcp = FastMCP() + + @mcp.resource("data://{id}{?format,limit}") + def get_data(id: str, format: str = "json", limit: int = 10) -> str: + return f"id={id}, format={format}, limit={limit}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("data://123")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "id=123, format=json, limit=10" + + result = await client.read_resource(AnyUrl("data://123?format=xml")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "id=123, format=xml, limit=10" + + result = await client.read_resource( + AnyUrl("data://123?format=csv&limit=50") + ) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "id=123, format=csv, limit=50" + + async def test_templates_match_in_order_of_definition(self): + """If a wildcard template is defined first, it will take priority.""" + mcp = FastMCP() + + @mcp.resource("resource://{param*}") + def template_resource(param: str) -> str: + return f"Template resource 1: {param}" + + @mcp.resource("resource://{x}/{y}") + def template_resource_with_params(x: str, y: str) -> str: + return f"Template resource 2: {x}/{y}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://a/b/c")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Template resource 1: a/b/c" + + result = await client.read_resource(AnyUrl("resource://a/b")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Template resource 1: a/b" + + async def test_templates_shadow_each_other_reorder(self): + """If a wildcard template is defined second, it will *not* take priority.""" + mcp = FastMCP() + + @mcp.resource("resource://{x}/{y}") + def template_resource_with_params(x: str, y: str) -> str: + return f"Template resource 1: {x}/{y}" + + @mcp.resource("resource://{param*}") + def template_resource(param: str) -> str: + return f"Template resource 2: {param}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://a/b/c")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Template resource 2: a/b/c" + + result = await client.read_resource(AnyUrl("resource://a/b")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Template resource 1: a/b" + + async def test_resource_template_with_annotations(self): + """Test that resource template annotations are visible to clients.""" + mcp = FastMCP() + + @mcp.resource( + "api://users/{user_id}", + annotations={"httpMethod": "GET", "Cache-Control": "no-cache"}, + ) + def get_user(user_id: str) -> str: + return f"User {user_id} data" + + async with Client(mcp) as client: + templates = await client.list_resource_templates() + assert len(templates) == 1 + + template = templates[0] + assert template.uriTemplate == "api://users/{user_id}" + + assert template.annotations is not None + assert hasattr(template.annotations, "httpMethod") + assert getattr(template.annotations, "httpMethod") == "GET" + assert hasattr(template.annotations, "Cache-Control") + assert getattr(template.annotations, "Cache-Control") == "no-cache" + + +class TestResourceTemplateContext: + async def test_resource_template_context(self): + mcp = FastMCP() + + @mcp.resource("resource://{param}") + def resource_template(param: str, ctx: Context) -> str: + assert isinstance(ctx, Context) + return f"Resource template: {param} {ctx.request_id}" + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text.startswith("Resource template: test 1") + + async def test_resource_template_context_with_callable_object(self): + mcp = FastMCP() + + class MyResource: + def __call__(self, param: str, ctx: Context) -> str: + return f"Resource template: {param} {ctx.request_id}" + + template = ResourceTemplate.from_function( + MyResource(), uri_template="resource://{param}" + ) + mcp.add_template(template) + + async with Client(mcp) as client: + result = await client.read_resource(AnyUrl("resource://test")) + assert isinstance(result[0], TextResourceContents) + assert result[0].text.startswith("Resource template: test 1") + + +class TestResourceDecorator: + async def test_no_resources_before_decorator(self): + mcp = FastMCP() + + with pytest.raises(McpError, match="Unknown resource"): + async with Client(mcp) as client: + await client.read_resource("resource://data") + + async def test_resource_decorator(self): + mcp = FastMCP() + + @mcp.resource("resource://data") + def get_data() -> str: + return "Hello, world!" + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Hello, world!" + + async def test_resource_decorator_incorrect_usage(self): + mcp = FastMCP() + + with pytest.raises( + TypeError, match="The @resource decorator was used incorrectly" + ): + + @mcp.resource # Missing parentheses #type: ignore + def get_data() -> str: + return "Hello, world!" + + async def test_resource_decorator_with_name(self): + mcp = FastMCP() + + @mcp.resource("resource://data", name="custom-data") + def get_data() -> str: + return "Hello, world!" + + resources_dict = await mcp.get_resources() + resources = list(resources_dict.values()) + assert len(resources) == 1 + assert resources[0].name == "custom-data" + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Hello, world!" + + async def test_resource_decorator_with_description(self): + mcp = FastMCP() + + @mcp.resource("resource://data", description="Data resource") + def get_data() -> str: + return "Hello, world!" + + resources_dict = await mcp.get_resources() + resources = list(resources_dict.values()) + assert len(resources) == 1 + assert resources[0].description == "Data resource" + + async def test_resource_decorator_with_tags(self): + """Test that the resource decorator properly sets tags.""" + mcp = FastMCP() + + @mcp.resource("resource://data", tags={"example", "test-tag"}) + def get_data() -> str: + return "Hello, world!" + + resources_dict = await mcp.get_resources() + resources = list(resources_dict.values()) + assert len(resources) == 1 + assert resources[0].tags == {"example", "test-tag"} + + async def test_resource_decorator_instance_method(self): + mcp = FastMCP() + + class MyClass: + def __init__(self, prefix: str): + self.prefix = prefix + + def get_data(self) -> str: + return f"{self.prefix} Hello, world!" + + obj = MyClass("My prefix:") + + mcp.add_resource( + Resource.from_function( + obj.get_data, uri="resource://data", name="instance-resource" + ) + ) + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "My prefix: Hello, world!" + + async def test_resource_decorator_classmethod(self): + mcp = FastMCP() + + class MyClass: + prefix = "Class prefix:" + + @classmethod + def get_data(cls) -> str: + return f"{cls.prefix} Hello, world!" + + mcp.add_resource( + Resource.from_function( + MyClass.get_data, uri="resource://data", name="class-resource" + ) + ) + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Class prefix: Hello, world!" + + async def test_resource_decorator_classmethod_error(self): + mcp = FastMCP() + + with pytest.raises(ValueError, match="To decorate a classmethod"): + + class MyClass: + @mcp.resource("resource://data") + @classmethod + def get_data(cls) -> None: + pass + + async def test_resource_decorator_staticmethod(self): + mcp = FastMCP() + + class MyClass: + @mcp.resource("resource://data") + @staticmethod + def get_data() -> str: + return "Static Hello, world!" + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Static Hello, world!" + + async def test_resource_decorator_async_function(self): + mcp = FastMCP() + + @mcp.resource("resource://data") + async def get_data() -> str: + return "Async Hello, world!" + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Async Hello, world!" + + async def test_resource_decorator_staticmethod_order(self): + """Test that both decorator orders work for static methods""" + mcp = FastMCP() + + class MyClass: + @mcp.resource("resource://data") # type: ignore[misc] + @staticmethod + def get_data() -> str: + return "Static Hello, world!" + + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Static Hello, world!" + + async def test_resource_decorator_with_meta(self): + """Test that meta parameter is passed through the resource decorator.""" + mcp = FastMCP() + + meta_data = {"version": "1.0", "author": "test"} + + @mcp.resource("resource://data", meta=meta_data) + def get_data() -> str: + return "Hello, world!" + + resources_dict = await mcp.get_resources() + resource = resources_dict["resource://data"] + + assert resource.meta == meta_data + + async def test_resource_content_with_meta_in_response(self): + """Test that ResourceContent meta is passed through to MCP response.""" + mcp = FastMCP() + + @mcp.resource("resource://widget") + def get_widget() -> ResourceContent: + return ResourceContent( + content="content", + mime_type="text/html", + meta={"csp": "script-src 'self'", "version": "1.0"}, + ) + + async with Client(mcp) as client: + result = await client.read_resource("resource://widget") + assert len(result) == 1 + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "content" + assert result[0].mimeType == "text/html" + assert isinstance(result[0], TextResourceContents) + assert result[0].meta == {"csp": "script-src 'self'", "version": "1.0"} + + async def test_resource_content_binary_with_meta(self): + """Test that ResourceContent with binary content and meta works.""" + mcp = FastMCP() + + @mcp.resource("resource://binary") + def get_binary() -> ResourceContent: + return ResourceContent( + content=b"\x00\x01\x02", + meta={"encoding": "raw"}, + ) + + async with Client(mcp) as client: + result = await client.read_resource("resource://binary") + assert len(result) == 1 + assert hasattr(result[0], "blob") + assert isinstance(result[0], BlobResourceContents) + assert result[0].meta == {"encoding": "raw"} + + async def test_resource_content_without_meta(self): + """Test that ResourceContent without meta works (meta is None).""" + mcp = FastMCP() + + @mcp.resource("resource://plain") + def get_plain() -> ResourceContent: + return ResourceContent(content="plain content") + + async with Client(mcp) as client: + result = await client.read_resource("resource://plain") + assert len(result) == 1 + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "plain content" + assert isinstance(result[0], TextResourceContents) + assert result[0].meta is None + + +class TestTemplateDecorator: + async def test_template_decorator(self): + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data(name: str) -> str: + return f"Data for {name}" + + templates_dict = await mcp.get_resource_templates() + templates = list(templates_dict.values()) + assert len(templates) == 1 + assert templates[0].name == "get_data" + assert templates[0].uri_template == "resource://{name}/data" + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Data for test" + + async def test_template_decorator_incorrect_usage(self): + mcp = FastMCP() + + with pytest.raises( + TypeError, match="The @resource decorator was used incorrectly" + ): + + @mcp.resource # Missing parentheses #type: ignore + def get_data(name: str) -> str: + return f"Data for {name}" + + async def test_template_decorator_with_name(self): + mcp = FastMCP() + + @mcp.resource("resource://{name}/data", name="custom-template") + def get_data(name: str) -> str: + return f"Data for {name}" + + templates_dict = await mcp.get_resource_templates() + templates = list(templates_dict.values()) + assert len(templates) == 1 + assert templates[0].name == "custom-template" + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Data for test" + + async def test_template_decorator_with_description(self): + mcp = FastMCP() + + @mcp.resource("resource://{name}/data", description="Template description") + def get_data(name: str) -> str: + return f"Data for {name}" + + templates_dict = await mcp.get_resource_templates() + templates = list(templates_dict.values()) + assert len(templates) == 1 + assert templates[0].description == "Template description" + + async def test_template_decorator_instance_method(self): + mcp = FastMCP() + + class MyClass: + def __init__(self, prefix: str): + self.prefix = prefix + + def get_data(self, name: str) -> str: + return f"{self.prefix} Data for {name}" + + obj = MyClass("My prefix:") + template = ResourceTemplate.from_function( + obj.get_data, + uri_template="resource://{name}/data", + name="instance-template", + ) + mcp.add_template(template) + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "My prefix: Data for test" + + async def test_template_decorator_classmethod(self): + mcp = FastMCP() + + class MyClass: + prefix = "Class prefix:" + + @classmethod + def get_data(cls, name: str) -> str: + return f"{cls.prefix} Data for {name}" + + template = ResourceTemplate.from_function( + MyClass.get_data, + uri_template="resource://{name}/data", + name="class-template", + ) + mcp.add_template(template) + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Class prefix: Data for test" + + async def test_template_decorator_staticmethod(self): + mcp = FastMCP() + + class MyClass: + @mcp.resource("resource://{name}/data") + @staticmethod + def get_data(name: str) -> str: + return f"Static Data for {name}" + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Static Data for test" + + async def test_template_decorator_async_function(self): + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + async def get_data(name: str) -> str: + return f"Async Data for {name}" + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result[0], TextResourceContents) + assert result[0].text == "Async Data for test" + + async def test_template_decorator_with_tags(self): + """Test that the template decorator properly sets tags.""" + mcp = FastMCP() + + @mcp.resource("resource://{param}", tags={"template", "test-tag"}) + def template_resource(param: str) -> str: + return f"Template resource: {param}" + + templates_dict = await mcp.get_resource_templates() + template = templates_dict["resource://{param}"] + assert template.tags == {"template", "test-tag"} + + async def test_template_decorator_wildcard_param(self): + mcp = FastMCP() + + @mcp.resource("resource://{param*}") + def template_resource(param: str) -> str: + return f"Template resource: {param}" + + templates_dict = await mcp.get_resource_templates() + template = templates_dict["resource://{param*}"] + assert template.uri_template == "resource://{param*}" + assert template.name == "template_resource" + + async def test_template_decorator_with_meta(self): + """Test that meta parameter is passed through the template decorator.""" + mcp = FastMCP() + + meta_data = {"version": "2.0", "template": "test"} + + @mcp.resource("resource://{param}/data", meta=meta_data) + def get_template_data(param: str) -> str: + return f"Data for {param}" + + templates_dict = await mcp.get_resource_templates() + template = templates_dict["resource://{param}/data"] + + assert template.meta == meta_data diff --git a/tests/server/providers/test_local_provider_tools.py b/tests/server/providers/test_local_provider_tools.py new file mode 100644 index 0000000000..2a0e276f70 --- /dev/null +++ b/tests/server/providers/test_local_provider_tools.py @@ -0,0 +1,1466 @@ +"""Tests for tool behavior in LocalProvider. + +Tests cover: +- Tool return types and serialization +- Tool parameters and validation +- Tool output schemas +- Tool context injection +- Tool decorator patterns +""" + +import base64 +import datetime +import functools +import json +import uuid +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Annotated, Any, Literal + +import pytest +from inline_snapshot import snapshot +from mcp.types import ( + AudioContent, + BlobResourceContents, + EmbeddedResource, + ImageContent, + TextContent, +) +from pydantic import AnyUrl, BaseModel, Field, TypeAdapter +from typing_extensions import TypedDict + +from fastmcp import Client, Context, FastMCP +from fastmcp.client.client import CallToolResult +from fastmcp.exceptions import ToolError +from fastmcp.tools.tool import Tool, ToolResult +from fastmcp.utilities.json_schema import compress_schema +from fastmcp.utilities.types import Audio, File, Image + + +def _normalize_anyof_order(schema): + """Normalize the order of items in anyOf arrays for consistent comparison.""" + if isinstance(schema, dict): + if "anyOf" in schema: + schema = schema.copy() + schema["anyOf"] = sorted(schema["anyOf"], key=str) + return {k: _normalize_anyof_order(v) for k, v in schema.items()} + elif isinstance(schema, list): + return [_normalize_anyof_order(item) for item in schema] + return schema + + +class PersonTypedDict(TypedDict): + name: str + age: int + + +class PersonModel(BaseModel): + name: str + age: int + + +@dataclass +class PersonDataclass: + name: str + age: int + + +@pytest.fixture +def tool_server(): + mcp = FastMCP() + + @mcp.tool + def add(x: int, y: int) -> int: + return x + y + + @mcp.tool + def list_tool() -> list[str | int]: + return ["x", 2] + + @mcp.tool + def error_tool() -> None: + raise ValueError("Test error") + + @mcp.tool + def image_tool(path: str) -> Image: + return Image(path) + + @mcp.tool + def audio_tool(path: str) -> Audio: + return Audio(path) + + @mcp.tool + def file_tool(path: str) -> File: + return File(path) + + @mcp.tool + def mixed_content_tool() -> list[TextContent | ImageContent | EmbeddedResource]: + return [ + TextContent(type="text", text="Hello"), + ImageContent(type="image", data="abc", mimeType="application/octet-stream"), + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + blob=base64.b64encode(b"abc").decode(), + mimeType="application/octet-stream", + uri=AnyUrl("file:///test.bin"), + ), + ), + ] + + @mcp.tool(output_schema=None) + def mixed_list_fn(image_path: str) -> list: + return [ + "text message", + Image(image_path), + {"key": "value"}, + TextContent(type="text", text="direct content"), + ] + + @mcp.tool(output_schema=None) + def mixed_audio_list_fn(audio_path: str) -> list: + return [ + "text message", + Audio(audio_path), + {"key": "value"}, + TextContent(type="text", text="direct content"), + ] + + @mcp.tool(output_schema=None) + def mixed_file_list_fn(file_path: str) -> list: + return [ + "text message", + File(file_path), + {"key": "value"}, + TextContent(type="text", text="direct content"), + ] + + @mcp.tool + def file_text_tool() -> File: + return File(data=b"hello world", format="plain") + + return mcp + + +class TestToolReturnTypes: + async def test_string(self): + mcp = FastMCP() + + @mcp.tool + def string_tool() -> str: + return "Hello, world!" + + async with Client(mcp) as client: + result = await client.call_tool("string_tool", {}) + assert result.data == "Hello, world!" + + async def test_bytes(self, tmp_path: Path): + mcp = FastMCP() + + @mcp.tool + def bytes_tool() -> bytes: + return b"Hello, world!" + + async with Client(mcp) as client: + result = await client.call_tool("bytes_tool", {}) + assert result.data == "Hello, world!" + + async def test_uuid(self): + mcp = FastMCP() + + test_uuid = uuid.uuid4() + + @mcp.tool + def uuid_tool() -> uuid.UUID: + return test_uuid + + async with Client(mcp) as client: + result = await client.call_tool("uuid_tool", {}) + assert result.data == str(test_uuid) + + async def test_path(self): + mcp = FastMCP() + + test_path = Path("/tmp/test.txt") + + @mcp.tool + def path_tool() -> Path: + return test_path + + async with Client(mcp) as client: + result = await client.call_tool("path_tool", {}) + assert result.data == str(test_path) + + async def test_datetime(self): + mcp = FastMCP() + + dt = datetime.datetime(2025, 4, 25, 1, 2, 3) + + @mcp.tool + def datetime_tool() -> datetime.datetime: + return dt + + async with Client(mcp) as client: + result = await client.call_tool("datetime_tool", {}) + assert result.data == dt + + async def test_image(self, tmp_path: Path): + mcp = FastMCP() + + @mcp.tool + def image_tool(path: str) -> Image: + return Image(path) + + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + async with Client(mcp) as client: + result = await client.call_tool("image_tool", {"path": str(image_path)}) + assert result.structured_content is None + content = result.content[0] + assert isinstance(content, ImageContent) + assert content.type == "image" + assert content.mimeType == "image/png" + decoded = base64.b64decode(content.data) + assert decoded == b"fake png data" + + async def test_audio(self, tmp_path: Path): + mcp = FastMCP() + + @mcp.tool + def audio_tool(path: str) -> Audio: + return Audio(path) + + audio_path = tmp_path / "test.wav" + audio_path.write_bytes(b"fake wav data") + + async with Client(mcp) as client: + result = await client.call_tool("audio_tool", {"path": str(audio_path)}) + content = result.content[0] + assert isinstance(content, AudioContent) + assert content.type == "audio" + assert content.mimeType == "audio/wav" + decoded = base64.b64decode(content.data) + assert decoded == b"fake wav data" + + async def test_file(self, tmp_path: Path): + mcp = FastMCP() + + @mcp.tool + def file_tool(path: str) -> File: + return File(path) + + file_path = tmp_path / "test.bin" + file_path.write_bytes(b"test file data") + + async with Client(mcp) as client: + result = await client.call_tool("file_tool", {"path": str(file_path)}) + content = result.content[0] + assert isinstance(content, EmbeddedResource) + assert content.type == "resource" + resource = content.resource + assert resource.mimeType == "application/octet-stream" + assert hasattr(resource, "blob") + blob_data = getattr(resource, "blob") + decoded = base64.b64decode(blob_data) + assert decoded == b"test file data" + assert str(resource.uri) == file_path.resolve().as_uri() + + async def test_tool_mixed_content(self, tool_server: FastMCP): + async with Client(tool_server) as client: + result = await client.call_tool("mixed_content_tool", {}) + assert len(result.content) == 3 + content1 = result.content[0] + content2 = result.content[1] + content3 = result.content[2] + assert isinstance(content1, TextContent) + assert content1.text == "Hello" + assert isinstance(content2, ImageContent) + assert content2.mimeType == "application/octet-stream" + assert content2.data == "abc" + assert isinstance(content3, EmbeddedResource) + assert content3.type == "resource" + resource = content3.resource + assert resource.mimeType == "application/octet-stream" + assert hasattr(resource, "blob") + blob_data = getattr(resource, "blob") + decoded = base64.b64decode(blob_data) + assert decoded == b"abc" + + async def test_tool_mixed_list_with_image( + self, tool_server: FastMCP, tmp_path: Path + ): + """Test that lists containing Image objects and other types are handled + correctly. Items now preserve their original order.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"test image data") + + async with Client(tool_server) as client: + result = await client.call_tool( + "mixed_list_fn", {"image_path": str(image_path)} + ) + assert len(result.content) == 4 + content1 = result.content[0] + assert isinstance(content1, TextContent) + assert content1.text == "text message" + content2 = result.content[1] + assert isinstance(content2, ImageContent) + assert content2.mimeType == "image/png" + assert base64.b64decode(content2.data) == b"test image data" + content3 = result.content[2] + assert isinstance(content3, TextContent) + assert json.loads(content3.text) == {"key": "value"} + content4 = result.content[3] + assert isinstance(content4, TextContent) + assert content4.text == "direct content" + + async def test_tool_mixed_list_with_audio( + self, tool_server: FastMCP, tmp_path: Path + ): + """Test that lists containing Audio objects and other types are handled + correctly. Items now preserve their original order.""" + audio_path = tmp_path / "test.wav" + audio_path.write_bytes(b"test audio data") + + async with Client(tool_server) as client: + result = await client.call_tool( + "mixed_audio_list_fn", {"audio_path": str(audio_path)} + ) + assert len(result.content) == 4 + content1 = result.content[0] + assert isinstance(content1, TextContent) + assert content1.text == "text message" + content2 = result.content[1] + assert isinstance(content2, AudioContent) + assert content2.mimeType == "audio/wav" + assert base64.b64decode(content2.data) == b"test audio data" + content3 = result.content[2] + assert isinstance(content3, TextContent) + assert json.loads(content3.text) == {"key": "value"} + content4 = result.content[3] + assert isinstance(content4, TextContent) + assert content4.text == "direct content" + + async def test_tool_mixed_list_with_file( + self, tool_server: FastMCP, tmp_path: Path + ): + """Test that lists containing File objects and other types are handled + correctly. Items now preserve their original order.""" + file_path = tmp_path / "test.bin" + file_path.write_bytes(b"test file data") + + async with Client(tool_server) as client: + result = await client.call_tool( + "mixed_file_list_fn", {"file_path": str(file_path)} + ) + assert len(result.content) == 4 + content1 = result.content[0] + assert isinstance(content1, TextContent) + assert content1.text == "text message" + content2 = result.content[1] + assert isinstance(content2, EmbeddedResource) + assert content2.type == "resource" + resource = content2.resource + assert resource.mimeType == "application/octet-stream" + assert hasattr(resource, "blob") + blob_data = getattr(resource, "blob") + assert base64.b64decode(blob_data) == b"test file data" + content3 = result.content[2] + assert isinstance(content3, TextContent) + assert json.loads(content3.text) == {"key": "value"} + content4 = result.content[3] + assert isinstance(content4, TextContent) + assert content4.text == "direct content" + + +class TestToolParameters: + async def test_parameter_descriptions_with_field_annotations(self): + mcp = FastMCP("Test Server") + + @mcp.tool + def greet( + name: Annotated[str, Field(description="The name to greet")], + title: Annotated[str, Field(description="Optional title", default="")], + ) -> str: + """A greeting tool""" + return f"Hello {title} {name}" + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + tool = tools[0] + + properties = tool.inputSchema["properties"] + assert "name" in properties + assert properties["name"]["description"] == "The name to greet" + assert "title" in properties + assert properties["title"]["description"] == "Optional title" + assert properties["title"]["default"] == "" + assert tool.inputSchema["required"] == ["name"] + + async def test_parameter_descriptions_with_field_defaults(self): + mcp = FastMCP("Test Server") + + @mcp.tool + def greet( + name: str = Field(description="The name to greet"), + title: str = Field(description="Optional title", default=""), + ) -> str: + """A greeting tool""" + return f"Hello {title} {name}" + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + tool = tools[0] + + properties = tool.inputSchema["properties"] + assert "name" in properties + assert properties["name"]["description"] == "The name to greet" + assert "title" in properties + assert properties["title"]["description"] == "Optional title" + assert properties["title"]["default"] == "" + assert tool.inputSchema["required"] == ["name"] + + async def test_tool_with_bytes_input(self): + mcp = FastMCP() + + @mcp.tool + def process_image(image: bytes) -> Image: + return Image(data=image) + + async with Client(mcp) as client: + result = await client.call_tool( + "process_image", {"image": b"fake png data"} + ) + assert result.structured_content is None + assert isinstance(result.content[0], ImageContent) + assert result.content[0].mimeType == "image/png" + assert result.content[0].data == base64.b64encode(b"fake png data").decode() + + async def test_tool_with_invalid_input(self): + mcp = FastMCP() + + @mcp.tool + def my_tool(x: int) -> int: + return x + 1 + + async with Client(mcp) as client: + with pytest.raises( + ToolError, + match="Input should be a valid integer", + ): + await client.call_tool("my_tool", {"x": "not an int"}) + + async def test_tool_int_coercion(self): + """Test that string ints are coerced by default.""" + mcp = FastMCP() + + @mcp.tool + def add_one(x: int) -> int: + return x + 1 + + async with Client(mcp) as client: + result = await client.call_tool("add_one", {"x": "42"}) + assert result.data == 43 + + async def test_tool_bool_coercion(self): + """Test that string bools are coerced by default.""" + mcp = FastMCP() + + @mcp.tool + def toggle(flag: bool) -> bool: + return not flag + + async with Client(mcp) as client: + result = await client.call_tool("toggle", {"flag": "true"}) + assert result.data is False + + result = await client.call_tool("toggle", {"flag": "false"}) + assert result.data is True + + async def test_annotated_field_validation(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: Annotated[int, Field(ge=1)]) -> None: + pass + + async with Client(mcp) as client: + with pytest.raises( + ToolError, + match="Input should be greater than or equal to 1", + ): + await client.call_tool("analyze", {"x": 0}) + + async def test_default_field_validation(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: int = Field(ge=1)) -> None: + pass + + async with Client(mcp) as client: + with pytest.raises( + ToolError, + match="Input should be greater than or equal to 1", + ): + await client.call_tool("analyze", {"x": 0}) + + async def test_default_field_is_still_required_if_no_default_specified(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: int = Field()) -> None: + pass + + async with Client(mcp) as client: + with pytest.raises(ToolError, match="Missing required argument"): + await client.call_tool("analyze", {}) + + async def test_literal_type_validation_error(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: Literal["a", "b"]) -> None: + pass + + async with Client(mcp) as client: + with pytest.raises( + ToolError, + match="Input should be 'a' or 'b'", + ): + await client.call_tool("analyze", {"x": "c"}) + + async def test_literal_type_validation_success(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: Literal["a", "b"]) -> str: + return x + + async with Client(mcp) as client: + result = await client.call_tool("analyze", {"x": "a"}) + assert result.data == "a" + + async def test_enum_type_validation_error(self): + mcp = FastMCP() + + class MyEnum(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + @mcp.tool + def analyze(x: MyEnum) -> str: + return x.value + + async with Client(mcp) as client: + with pytest.raises( + ToolError, + match="Input should be 'red', 'green' or 'blue'", + ): + await client.call_tool("analyze", {"x": "some-color"}) + + async def test_enum_type_validation_success(self): + mcp = FastMCP() + + class MyEnum(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + @mcp.tool + def analyze(x: MyEnum) -> str: + return x.value + + async with Client(mcp) as client: + result = await client.call_tool("analyze", {"x": "red"}) + assert result.data == "red" + + async def test_union_type_validation(self): + mcp = FastMCP() + + @mcp.tool + def analyze(x: int | float) -> str: + return str(x) + + async with Client(mcp) as client: + result = await client.call_tool("analyze", {"x": 1}) + assert result.data == "1" + + result = await client.call_tool("analyze", {"x": 1.0}) + assert result.data == "1.0" + + with pytest.raises( + ToolError, + match="Input should be a valid", + ): + await client.call_tool("analyze", {"x": "not a number"}) + + async def test_path_type(self): + mcp = FastMCP() + + @mcp.tool + def send_path(path: Path) -> str: + assert isinstance(path, Path) + return str(path) + + test_path = Path("tmp") / "test.txt" + + async with Client(mcp) as client: + result = await client.call_tool("send_path", {"path": str(test_path)}) + assert result.data == str(test_path) + + async def test_path_type_error(self): + mcp = FastMCP() + + @mcp.tool + def send_path(path: Path) -> str: + return str(path) + + async with Client(mcp) as client: + with pytest.raises(ToolError, match="Input is not a valid path"): + await client.call_tool("send_path", {"path": 1}) + + async def test_uuid_type(self): + mcp = FastMCP() + + @mcp.tool + def send_uuid(x: uuid.UUID) -> str: + assert isinstance(x, uuid.UUID) + return str(x) + + test_uuid = uuid.uuid4() + + async with Client(mcp) as client: + result = await client.call_tool("send_uuid", {"x": test_uuid}) + assert result.data == str(test_uuid) + + async def test_uuid_type_error(self): + mcp = FastMCP() + + @mcp.tool + def send_uuid(x: uuid.UUID) -> str: + return str(x) + + async with Client(mcp) as client: + with pytest.raises(ToolError, match="Input should be a valid UUID"): + await client.call_tool("send_uuid", {"x": "not a uuid"}) + + async def test_datetime_type(self): + mcp = FastMCP() + + @mcp.tool + def send_datetime(x: datetime.datetime) -> str: + return x.isoformat() + + dt = datetime.datetime(2025, 4, 25, 1, 2, 3) + + async with Client(mcp) as client: + result = await client.call_tool("send_datetime", {"x": dt}) + assert result.data == dt.isoformat() + + async def test_datetime_type_parse_string(self): + mcp = FastMCP() + + @mcp.tool + def send_datetime(x: datetime.datetime) -> str: + return x.isoformat() + + async with Client(mcp) as client: + result = await client.call_tool( + "send_datetime", {"x": "2021-01-01T00:00:00"} + ) + assert result.data == "2021-01-01T00:00:00" + + async def test_datetime_type_error(self): + mcp = FastMCP() + + @mcp.tool + def send_datetime(x: datetime.datetime) -> str: + return x.isoformat() + + async with Client(mcp) as client: + with pytest.raises(ToolError, match="Input should be a valid datetime"): + await client.call_tool("send_datetime", {"x": "not a datetime"}) + + async def test_date_type(self): + mcp = FastMCP() + + @mcp.tool + def send_date(x: datetime.date) -> str: + return x.isoformat() + + async with Client(mcp) as client: + result = await client.call_tool("send_date", {"x": datetime.date.today()}) + assert result.data == datetime.date.today().isoformat() + + async def test_date_type_parse_string(self): + mcp = FastMCP() + + @mcp.tool + def send_date(x: datetime.date) -> str: + return x.isoformat() + + async with Client(mcp) as client: + result = await client.call_tool("send_date", {"x": "2021-01-01"}) + assert result.data == "2021-01-01" + + async def test_timedelta_type(self): + mcp = FastMCP() + + @mcp.tool + def send_timedelta(x: datetime.timedelta) -> str: + return str(x) + + async with Client(mcp) as client: + result = await client.call_tool( + "send_timedelta", {"x": datetime.timedelta(days=1)} + ) + assert result.data == "1 day, 0:00:00" + + async def test_timedelta_type_parse_int(self): + """Test that int input is coerced to timedelta (seconds).""" + mcp = FastMCP() + + @mcp.tool + def send_timedelta(x: datetime.timedelta) -> str: + return str(x) + + async with Client(mcp) as client: + result = await client.call_tool("send_timedelta", {"x": 1000}) + assert ( + "0:16:40" in result.data or "16:40" in result.data + ) # 1000 seconds = 16 minutes 40 seconds + + async def test_annotated_string_description(self): + mcp = FastMCP() + + @mcp.tool + def f(x: Annotated[int, "A number"]): + return x + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + assert tools[0].inputSchema["properties"]["x"]["description"] == "A number" + + +class TestToolOutputSchema: + @pytest.mark.parametrize("annotation", [str, int, float, bool, list, AnyUrl]) + async def test_simple_output_schema(self, annotation): + mcp = FastMCP() + + @mcp.tool + def f() -> annotation: + return "hello" + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + + type_schema = TypeAdapter(annotation).json_schema() + type_schema = compress_schema(type_schema, prune_titles=True) + assert tools[0].outputSchema == { + "type": "object", + "properties": {"result": type_schema}, + "required": ["result"], + "x-fastmcp-wrap-result": True, + } + + @pytest.mark.parametrize( + "annotation", + [dict[str, int | str], PersonTypedDict, PersonModel, PersonDataclass], + ) + async def test_structured_output_schema(self, annotation): + mcp = FastMCP() + + @mcp.tool + def f() -> annotation: + return {"name": "John", "age": 30} + + async with Client(mcp) as client: + tools = await client.list_tools() + + type_schema = compress_schema( + TypeAdapter(annotation).json_schema(), prune_titles=True + ) + assert len(tools) == 1 + + actual_schema = _normalize_anyof_order(tools[0].outputSchema) + expected_schema = _normalize_anyof_order(type_schema) + assert actual_schema == expected_schema + + async def test_disabled_output_schema_no_structured_content(self): + mcp = FastMCP() + + @mcp.tool(output_schema=None) + def f() -> int: + return 42 + + async with Client(mcp) as client: + result = await client.call_tool("f", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "42" + assert result.structured_content is None + assert result.data is None + + async def test_manual_structured_content(self): + mcp = FastMCP() + + @mcp.tool + def f() -> ToolResult: + return ToolResult( + content="Hello, world!", structured_content={"message": "Hello, world!"} + ) + + assert f.output_schema is None + + async with Client(mcp) as client: + result = await client.call_tool("f", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello, world!" + assert result.structured_content == {"message": "Hello, world!"} + assert result.data == {"message": "Hello, world!"} + + async def test_output_schema_none_full_handshake(self): + """Test that output_schema=None works through full client/server + handshake. We test this by returning a scalar, which requires an output + schema to serialize.""" + mcp = FastMCP() + + @mcp.tool(output_schema=None) + def simple_tool() -> int: + return 42 + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "simple_tool") + assert tool.outputSchema is None + + result = await client.call_tool("simple_tool", {}) + assert result.structured_content is None + assert result.data is None + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "42" + + async def test_output_schema_explicit_object_full_handshake(self): + """Test explicit object output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool( + output_schema={ + "type": "object", + "properties": { + "greeting": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["greeting"], + } + ) + def explicit_tool() -> dict[str, Any]: + return {"greeting": "Hello", "count": 42} + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "explicit_tool") + expected_schema = { + "type": "object", + "properties": { + "greeting": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["greeting"], + } + assert tool.outputSchema == expected_schema + + result = await client.call_tool("explicit_tool", {}) + assert result.structured_content == {"greeting": "Hello", "count": 42} + assert result.data is not None + assert result.data.greeting == "Hello" # type: ignore[attr-defined] + assert result.data.count == 42 # type: ignore[attr-defined] + + async def test_output_schema_wrapped_primitive_full_handshake(self): + """Test wrapped primitive output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool + def primitive_tool() -> str: + return "Hello, primitives!" + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "primitive_tool") + expected_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + "x-fastmcp-wrap-result": True, + } + assert tool.outputSchema == expected_schema + + result = await client.call_tool("primitive_tool", {}) + assert result.structured_content == {"result": "Hello, primitives!"} + assert result.data == "Hello, primitives!" + + async def test_output_schema_complex_type_full_handshake(self): + """Test complex type output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool + def complex_tool() -> list[dict[str, int]]: + return [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "complex_tool") + expected_inner_schema = compress_schema( + TypeAdapter(list[dict[str, int]]).json_schema(), prune_titles=True + ) + expected_schema = { + "type": "object", + "properties": {"result": expected_inner_schema}, + "required": ["result"], + "x-fastmcp-wrap-result": True, + } + assert tool.outputSchema == expected_schema + + result = await client.call_tool("complex_tool", {}) + expected_data = [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + assert result.structured_content == {"result": expected_data} + assert result.data is not None + + async def test_output_schema_dataclass_full_handshake(self): + """Test dataclass output schema through full client/server handshake.""" + mcp = FastMCP() + + @dataclass + class User: + name: str + age: int + + @mcp.tool + def dataclass_tool() -> User: + return User(name="Alice", age=30) + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "dataclass_tool") + expected_schema = compress_schema( + TypeAdapter(User).json_schema(), prune_titles=True + ) + assert tool.outputSchema == expected_schema + assert ( + tool.outputSchema and "x-fastmcp-wrap-result" not in tool.outputSchema + ) + + result = await client.call_tool("dataclass_tool", {}) + assert result.structured_content == {"name": "Alice", "age": 30} + assert result.data is not None + assert result.data.name == "Alice" # type: ignore[attr-defined] + assert result.data.age == 30 # type: ignore[attr-defined] + + async def test_output_schema_mixed_content_types(self): + """Test tools with mixed content and output schemas.""" + mcp = FastMCP() + + @mcp.tool + def mixed_output() -> list[Any]: + return [ + "text message", + {"structured": "data"}, + TextContent(type="text", text="direct MCP content"), + ] + + async with Client(mcp) as client: + result = await client.call_tool("mixed_output", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(type="text", text="text message"), + TextContent(type="text", text='{"structured":"data"}'), + TextContent(type="text", text="direct MCP content"), + ], + structured_content={ + "result": [ + "text message", + {"structured": "data"}, + { + "type": "text", + "text": "direct MCP content", + "annotations": None, + "_meta": None, + }, + ] + }, + data=[ + "text message", + {"structured": "data"}, + { + "type": "text", + "text": "direct MCP content", + "annotations": None, + "_meta": None, + }, + ], + meta=None, + ) + ) + + async def test_output_schema_serialization_edge_cases(self): + """Test edge cases in output schema serialization.""" + mcp = FastMCP() + + @mcp.tool + def edge_case_tool() -> tuple[int, str]: + return (42, "hello") + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "edge_case_tool") + + assert tool.outputSchema and "x-fastmcp-wrap-result" in tool.outputSchema + + result = await client.call_tool("edge_case_tool", {}) + assert result.structured_content == {"result": [42, "hello"]} + assert result.data == [42, "hello"] + + +class TestToolContextInjection: + """Test context injection in tools.""" + + async def test_context_detection(self): + """Test that context parameters are properly detected.""" + mcp = FastMCP() + + @mcp.tool + def tool_with_context(x: int, ctx: Context) -> str: + return f"Request {ctx.request_id}: {x}" + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + assert tools[0].name == "tool_with_context" + + async def test_context_injection(self): + """Test that context is properly injected into tool calls.""" + mcp = FastMCP() + + @mcp.tool + def tool_with_context(x: int, ctx: Context) -> str: + assert isinstance(ctx, Context) + assert ctx.request_id is not None + return ctx.request_id + + async with Client(mcp) as client: + result = await client.call_tool("tool_with_context", {"x": 42}) + assert result.data == "1" + + async def test_async_context(self): + """Test that context works in async functions.""" + mcp = FastMCP() + + @mcp.tool + async def async_tool(x: int, ctx: Context) -> str: + assert ctx.request_id is not None + return f"Async request {ctx.request_id}: {x}" + + async with Client(mcp) as client: + result = await client.call_tool("async_tool", {"x": 42}) + assert result.data == "Async request 1: 42" + + async def test_optional_context(self): + """Test that context is optional.""" + mcp = FastMCP() + + @mcp.tool + def no_context(x: int) -> int: + return x * 2 + + async with Client(mcp) as client: + result = await client.call_tool("no_context", {"x": 21}) + assert result.data == 42 + + async def test_context_resource_access(self): + """Test that context can access resources.""" + mcp = FastMCP() + + @mcp.resource("test://data") + def test_resource() -> str: + return "resource data" + + @mcp.tool + async def tool_with_resource(ctx: Context) -> str: + r_iter = await ctx.read_resource("test://data") + r_list = list(r_iter) + assert len(r_list) == 1 + r = r_list[0] + return f"Read resource: {r.content} with mime type {r.mime_type}" + + async with Client(mcp) as client: + result = await client.call_tool("tool_with_resource", {}) + assert ( + result.data == "Read resource: resource data with mime type text/plain" + ) + + async def test_tool_decorator_with_tags(self): + """Test that the tool decorator properly sets tags.""" + mcp = FastMCP() + + @mcp.tool(tags={"example", "test-tag"}) + def sample_tool(x: int) -> int: + return x * 2 + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + + async def test_callable_object_with_context(self): + """Test that a callable object can be used as a tool with context.""" + mcp = FastMCP() + + class MyTool: + async def __call__(self, x: int, ctx: Context) -> int: + return x + int(ctx.request_id) + + mcp.add_tool(Tool.from_function(MyTool(), name="MyTool")) + + async with Client(mcp) as client: + result = await client.call_tool("MyTool", {"x": 2}) + assert result.data == 3 + + async def test_decorated_tool_with_functools_wraps(self): + """Regression test for #2524: @mcp.tool with functools.wraps decorator.""" + + def custom_decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + return wrapper + + mcp = FastMCP() + + @mcp.tool + @custom_decorator + async def decorated_tool(ctx: Context, query: str) -> str: + assert isinstance(ctx, Context) + return f"query: {query}" + + async with Client(mcp) as client: + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "decorated_tool") + assert "ctx" not in tool.inputSchema.get("properties", {}) + + result = await client.call_tool("decorated_tool", {"query": "test"}) + assert result.data == "query: test" + + +class TestToolDecorator: + async def test_no_tools_before_decorator(self): + mcp = FastMCP() + + from fastmcp.exceptions import NotFoundError + + with pytest.raises(NotFoundError, match="Unknown tool: add"): + await mcp._call_tool_mcp("add", {"x": 1, "y": 2}) + + async def test_tool_decorator(self): + mcp = FastMCP() + + @mcp.tool + def add(x: int, y: int) -> int: + return x + y + + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_without_parentheses(self): + """Test that @tool decorator works without parentheses.""" + mcp = FastMCP() + + @mcp.tool + def add(x: int, y: int) -> int: + return x + y + + tools = await mcp.get_tools() + assert "add" in tools + + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_with_name(self): + mcp = FastMCP() + + @mcp.tool(name="custom-add") + def add(x: int, y: int) -> int: + return x + y + + async with Client(mcp) as client: + result = await client.call_tool("custom-add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_with_description(self): + mcp = FastMCP() + + @mcp.tool(description="Add two numbers") + def add(x: int, y: int) -> int: + return x + y + + tools = await mcp._list_tools_mcp() + assert len(tools) == 1 + tool = tools[0] + assert tool.description == "Add two numbers" + + async def test_tool_decorator_instance_method(self): + mcp = FastMCP() + + class MyClass: + def __init__(self, x: int): + self.x = x + + def add(self, y: int) -> int: + return self.x + y + + obj = MyClass(10) + mcp.add_tool(Tool.from_function(obj.add)) + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 + + async def test_tool_decorator_classmethod(self): + mcp = FastMCP() + + class MyClass: + x: int = 10 + + @classmethod + def add(cls, y: int) -> int: + return cls.x + y + + mcp.add_tool(Tool.from_function(MyClass.add)) + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 + + async def test_tool_decorator_staticmethod(self): + mcp = FastMCP() + + class MyClass: + @mcp.tool + @staticmethod + def add(x: int, y: int) -> int: + return x + y + + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_async_function(self): + mcp = FastMCP() + + @mcp.tool + async def add(x: int, y: int) -> int: + return x + y + + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_classmethod_error(self): + mcp = FastMCP() + + with pytest.raises(ValueError, match="To decorate a classmethod"): + + class MyClass: + @mcp.tool + @classmethod + def add(cls, y: int) -> None: + pass + + async def test_tool_decorator_classmethod_async_function(self): + mcp = FastMCP() + + class MyClass: + x = 10 + + @classmethod + async def add(cls, y: int) -> int: + return cls.x + y + + mcp.add_tool(Tool.from_function(MyClass.add)) + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 + + async def test_tool_decorator_staticmethod_async_function(self): + mcp = FastMCP() + + class MyClass: + @staticmethod + async def add(x: int, y: int) -> int: + return x + y + + mcp.add_tool(Tool.from_function(MyClass.add)) + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_staticmethod_order(self): + """Test that the recommended decorator order works for static methods""" + mcp = FastMCP() + + class MyClass: + @mcp.tool + @staticmethod + def add_v1(x: int, y: int) -> int: + return x + y + + async with Client(mcp) as client: + result = await client.call_tool("add_v1", {"x": 1, "y": 2}) + assert result.data == 3 + + async def test_tool_decorator_with_tags(self): + """Test that the tool decorator properly sets tags.""" + mcp = FastMCP() + + @mcp.tool(tags={"example", "test-tag"}) + def sample_tool(x: int) -> int: + return x * 2 + + tools_dict = await mcp.get_tools() + assert len(tools_dict) == 1 + only_tool = next(iter(tools_dict.values())) + assert only_tool.tags == {"example", "test-tag"} + + async def test_add_tool_with_custom_name(self): + """Test adding a tool with a custom name using server.add_tool().""" + mcp = FastMCP() + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + mcp.add_tool(Tool.from_function(multiply, name="custom_multiply")) + + tools = await mcp.get_tools() + assert "custom_multiply" in tools + + async with Client(mcp) as client: + result = await client.call_tool("custom_multiply", {"a": 5, "b": 3}) + assert result.data == 15 + + assert "multiply" not in tools + + async def test_tool_with_annotated_arguments(self): + """Test that tools with annotated arguments work correctly.""" + mcp = FastMCP() + + @mcp.tool + def add( + x: Annotated[int, Field(description="x is an int")], + y: Annotated[str, Field(description="y is not an int")], + ) -> None: + pass + + tool = (await mcp.get_tools())["add"] + assert tool.parameters["properties"]["x"]["description"] == "x is an int" + assert tool.parameters["properties"]["y"]["description"] == "y is not an int" + + async def test_tool_with_field_defaults(self): + """Test that tools with annotated arguments work correctly.""" + mcp = FastMCP() + + @mcp.tool + def add( + x: int = Field(description="x is an int"), + y: str = Field(description="y is not an int"), + ) -> None: + pass + + tool = (await mcp.get_tools())["add"] + assert tool.parameters["properties"]["x"]["description"] == "x is an int" + assert tool.parameters["properties"]["y"]["description"] == "y is not an int" + + async def test_tool_direct_function_call(self): + """Test that tools can be registered via direct function call.""" + mcp = FastMCP() + + def standalone_function(x: int, y: int) -> int: + """A standalone function to be registered.""" + return x + y + + from fastmcp.tools import FunctionTool + + result_fn = mcp.tool(standalone_function, name="direct_call_tool") + + assert isinstance(result_fn, FunctionTool) + + tools = await mcp.get_tools() + assert tools["direct_call_tool"] is result_fn + + async with Client(mcp) as client: + result = await client.call_tool("direct_call_tool", {"x": 5, "y": 3}) + assert result.data == 8 + + async def test_tool_decorator_with_string_name(self): + """Test that @tool("custom_name") syntax works correctly.""" + mcp = FastMCP() + + @mcp.tool("string_named_tool") + def my_function(x: int) -> str: + """A function with a string name.""" + return f"Result: {x}" + + tools = await mcp.get_tools() + assert "string_named_tool" in tools + assert "my_function" not in tools + + async with Client(mcp) as client: + result = await client.call_tool("string_named_tool", {"x": 42}) + assert result.data == "Result: 42" + + async def test_tool_decorator_conflicting_names_error(self): + """Test that providing both positional and keyword name raises an error.""" + mcp = FastMCP() + + with pytest.raises( + TypeError, + match="Cannot specify both a name as first argument and as keyword argument", + ): + + @mcp.tool("positional_name", name="keyword_name") + def my_function(x: int) -> str: + return f"Result: {x}" + + async def test_tool_decorator_with_output_schema(self): + mcp = FastMCP() + + with pytest.raises( + ValueError, match="Output schemas must represent object types" + ): + + @mcp.tool(output_schema={"type": "integer"}) + def my_function(x: int) -> str: + return f"Result: {x}" + + async def test_tool_decorator_with_meta(self): + """Test that meta parameter is passed through the tool decorator.""" + mcp = FastMCP() + + meta_data = {"version": "1.0", "author": "test"} + + @mcp.tool(meta=meta_data) + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + tools_dict = await mcp.get_tools() + tool = tools_dict["multiply"] + + assert tool.meta == meta_data diff --git a/tests/server/proxy/test_proxy_client.py b/tests/server/proxy/test_proxy_client.py index 371fb65801..8649e345df 100644 --- a/tests/server/proxy/test_proxy_client.py +++ b/tests/server/proxy/test_proxy_client.py @@ -416,8 +416,8 @@ async def test_client_factory_creates_fresh_sessions(self, fastmcp_server: FastM assert proxy_via_factory is not None # Verify they have the expected client factory behavior - assert hasattr(proxy_via_as_proxy, "_tool_manager") - assert hasattr(proxy_via_factory, "_tool_manager") + assert hasattr(proxy_via_as_proxy, "_local_provider") + assert hasattr(proxy_via_factory, "_local_provider") async def test_connected_client_reuses_sessions(self, fastmcp_server: FastMCP): """Test that connected clients passed to as_proxy reuse sessions (preserves #959 behavior).""" @@ -427,4 +427,4 @@ async def test_connected_client_reuses_sessions(self, fastmcp_server: FastMCP): # Verify the proxy is created successfully and uses session reuse assert proxy is not None - assert hasattr(proxy, "_tool_manager") + assert hasattr(proxy, "_local_provider") diff --git a/tests/server/tasks/test_sync_function_task_disabled.py b/tests/server/tasks/test_sync_function_task_disabled.py index 854605e61b..bdedbe01b2 100644 --- a/tests/server/tasks/test_sync_function_task_disabled.py +++ b/tests/server/tasks/test_sync_function_task_disabled.py @@ -137,7 +137,7 @@ async def async_resource() -> str: return "data" # Resource should have task mode="optional" and be a FunctionResource - resource = await mcp._resource_manager.get_resource("test://async") + resource = await mcp.get_resource("test://async") assert isinstance(resource, FunctionResource) assert resource.task_config.mode == "optional" @@ -179,7 +179,7 @@ def sync_resource() -> str: """A synchronous resource.""" return "data" - resource = await mcp._resource_manager.get_resource("test://sync") + resource = await mcp.get_resource("test://sync") assert isinstance(resource, FunctionResource) assert resource.task_config.mode == "forbidden" diff --git a/tests/server/tasks/test_task_config.py b/tests/server/tasks/test_task_config.py index e0158c1b09..6287df933d 100644 --- a/tests/server/tasks/test_task_config.py +++ b/tests/server/tasks/test_task_config.py @@ -30,7 +30,7 @@ async def test_task_true_normalizes_to_optional(self): async def my_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("my_tool") + tool = await mcp.get_tool("my_tool") assert isinstance(tool, Tool) assert tool.task_config.mode == "optional" @@ -42,7 +42,7 @@ async def test_task_false_normalizes_to_forbidden(self): async def my_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("my_tool") + tool = await mcp.get_tool("my_tool") assert isinstance(tool, Tool) assert tool.task_config.mode == "forbidden" @@ -54,7 +54,7 @@ async def test_task_config_passed_directly(self): async def my_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("my_tool") + tool = await mcp.get_tool("my_tool") assert isinstance(tool, Tool) assert tool.task_config.mode == "required" @@ -67,7 +67,7 @@ async def test_default_task_inherits_server_default(self): def my_tool_sync() -> str: return "ok" - tool = await mcp_no_tasks._tool_manager.get_tool("my_tool_sync") + tool = await mcp_no_tasks.get_tool("my_tool_sync") assert isinstance(tool, Tool) assert tool.task_config.mode == "forbidden" @@ -78,7 +78,7 @@ def my_tool_sync() -> str: async def my_tool_async() -> str: return "ok" - tool2 = await mcp_tasks._tool_manager.get_tool("my_tool_async") + tool2 = await mcp_tasks.get_tool("my_tool_async") assert isinstance(tool2, Tool) assert tool2.task_config.mode == "optional" @@ -350,7 +350,7 @@ async def test_sync_function_with_forbidden_mode_ok(self): def sync_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("sync_tool") + tool = await mcp.get_tool("sync_tool") assert isinstance(tool, Tool) assert tool.task_config.mode == "forbidden" @@ -376,7 +376,7 @@ async def test_tool_inherits_poll_interval(self): async def my_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("my_tool") + tool = await mcp.get_tool("my_tool") assert isinstance(tool, Tool) assert tool.task_config.poll_interval == timedelta(seconds=2) @@ -388,6 +388,6 @@ async def test_task_true_uses_default_poll_interval(self): async def my_tool() -> str: return "ok" - tool = await mcp._tool_manager.get_tool("my_tool") + tool = await mcp.get_tool("my_tool") assert isinstance(tool, Tool) assert tool.task_config.poll_interval == timedelta(seconds=5) diff --git a/tests/server/test_mount.py b/tests/server/test_mount.py index 06808951d4..429981d268 100644 --- a/tests/server/test_mount.py +++ b/tests/server/test_mount.py @@ -625,8 +625,8 @@ def temp_tool() -> str: tools = await main_app.get_tools() assert "sub_temp_tool" in tools - # Remove the tool from sub_app - sub_app._tool_manager._tools.pop("temp_tool") + # Remove the tool from sub_app using public API + sub_app.remove_tool("temp_tool") # The tool should no longer be accessible tools = await main_app.get_tools() @@ -864,7 +864,8 @@ async def test_as_proxy_defaults_false(self): sub = FastMCP("Sub") mcp.mount(sub, "sub") - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] # With namespace, we get TransformingProvider wrapping FastMCPProvider assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) @@ -876,7 +877,8 @@ async def test_as_proxy_false(self): mcp.mount(sub, "sub", as_proxy=False) - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] # With namespace, we get TransformingProvider wrapping FastMCPProvider assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) @@ -888,7 +890,8 @@ async def test_as_proxy_true(self): mcp.mount(sub, "sub", as_proxy=True) - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] # With namespace, we get TransformingProvider wrapping FastMCPProvider assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) @@ -912,7 +915,8 @@ async def server_lifespan(mcp: FastMCP): mcp.mount(sub, "sub") # Server should be mounted directly without auto-proxying - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) assert provider._wrapped.server is sub @@ -924,7 +928,8 @@ async def test_as_proxy_ignored_for_proxy_mounts_default(self): mcp.mount(sub_proxy, "sub") - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) assert provider._wrapped.server is sub_proxy @@ -936,7 +941,8 @@ async def test_as_proxy_ignored_for_proxy_mounts_false(self): mcp.mount(sub_proxy, "sub", as_proxy=False) - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) assert provider._wrapped.server is sub_proxy @@ -948,7 +954,8 @@ async def test_as_proxy_ignored_for_proxy_mounts_true(self): mcp.mount(sub_proxy, "sub", as_proxy=True) - provider = mcp._providers[0] + # Index 1 because LocalProvider is at index 0 + provider = mcp._providers[1] assert isinstance(provider, TransformingProvider) assert isinstance(provider._wrapped, FastMCPProvider) assert provider._wrapped.server is sub_proxy @@ -1166,17 +1173,21 @@ async def test_route(request): async def test_mounted_servers_tracking(self): """Test that _providers list tracks mounted servers correctly.""" + from fastmcp.server.providers.local_provider import LocalProvider + main_server = FastMCP("MainServer") sub_server1 = FastMCP("SubServer1") sub_server2 = FastMCP("SubServer2") - # Initially no providers - assert len(main_server._providers) == 0 + # Initially only LocalProvider + assert len(main_server._providers) == 1 + assert isinstance(main_server._providers[0], LocalProvider) # Mount first server main_server.mount(sub_server1, "sub1") - assert len(main_server._providers) == 1 - provider1 = main_server._providers[0] + assert len(main_server._providers) == 2 + # LocalProvider is at index 0, mounted provider at index 1 + provider1 = main_server._providers[1] assert isinstance(provider1, TransformingProvider) assert isinstance(provider1._wrapped, FastMCPProvider) assert provider1._wrapped.server == sub_server1 @@ -1184,8 +1195,8 @@ async def test_mounted_servers_tracking(self): # Mount second server main_server.mount(sub_server2, "sub2") - assert len(main_server._providers) == 2 - provider2 = main_server._providers[1] + assert len(main_server._providers) == 3 + provider2 = main_server._providers[2] assert isinstance(provider2, TransformingProvider) assert isinstance(provider2._wrapped, FastMCPProvider) assert provider2._wrapped.server == sub_server2 diff --git a/tests/server/test_providers.py b/tests/server/test_providers.py index 205b493f52..9fe2d3bbe8 100644 --- a/tests/server/test_providers.py +++ b/tests/server/test_providers.py @@ -244,10 +244,10 @@ async def test_default_get_tool_falls_back_to_list(self, base_server: FastMCP): # Default get_tool should have called list_tools assert provider.list_tools_call_count >= 1 - async def test_dynamic_tools_come_first( + async def test_local_tools_come_first( self, base_server: FastMCP, dynamic_tools: list[Tool] ): - """Test that dynamic tools appear before static tools in list.""" + """Test that local tools (from LocalProvider) appear before other provider tools.""" provider = SimpleToolProvider(tools=dynamic_tools) base_server.add_provider(provider) @@ -255,8 +255,8 @@ async def test_dynamic_tools_come_first( tools: list[MCPTool] = await client.list_tools() tool_names = [tool.name for tool in tools] - # Dynamic tools should come first - assert tool_names[:2] == ["dynamic_multiply", "dynamic_add"] + # Local tools should come first (LocalProvider is first in _providers) + assert tool_names[:2] == ["static_add", "static_subtract"] async def test_empty_provider(self, base_server: FastMCP): """Test that empty provider doesn't affect behavior.""" diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 10b89b847f..5450b0da35 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -1,17 +1,12 @@ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from typing import Annotated import pytest -from mcp import McpError -from mcp.types import BlobResourceContents, TextContent, TextResourceContents -from pydantic import Field +from mcp.types import TextContent, TextResourceContents from fastmcp import Client, FastMCP from fastmcp.exceptions import NotFoundError -from fastmcp.prompts.prompt import FunctionPrompt, Prompt, PromptResult -from fastmcp.resources import Resource, ResourceContent, ResourceTemplate from fastmcp.tools import FunctionTool from fastmcp.tools.tool import Tool @@ -118,1023 +113,71 @@ def g(x: int) -> int: assert tools["g-tool"].description == "add two to a number" -class TestToolDecorator: - async def test_no_tools_before_decorator(self): - mcp = FastMCP() - - with pytest.raises(NotFoundError, match="Unknown tool: add"): - await mcp._call_tool_mcp("add", {"x": 1, "y": 2}) +class TestServerDelegation: + """Test that FastMCP properly delegates to LocalProvider.""" - async def test_tool_decorator(self): + async def test_tool_decorator_delegates_to_local_provider(self): + """Test that @mcp.tool registers with the local provider.""" mcp = FastMCP() @mcp.tool - def add(x: int, y: int) -> int: - return x + y - - async with Client(mcp) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result.data == 3 - - async def test_tool_decorator_without_parentheses(self): - """Test that @tool decorator works without parentheses.""" - mcp = FastMCP() - - # Test the @tool syntax without parentheses - @mcp.tool - def add(x: int, y: int) -> int: - return x + y - - # Verify the tool was registered correctly - tools = await mcp.get_tools() - assert "add" in tools - - # Verify it can be called - async with Client(mcp) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result.data == 3 - - async def test_tool_decorator_with_name(self): - mcp = FastMCP() - - @mcp.tool(name="custom-add") - def add(x: int, y: int) -> int: - return x + y - - async with Client(mcp) as client: - result = await client.call_tool("custom-add", {"x": 1, "y": 2}) - assert result.data == 3 + def my_tool() -> str: + return "result" - async def test_tool_decorator_with_description(self): - mcp = FastMCP() - - @mcp.tool(description="Add two numbers") - def add(x: int, y: int) -> int: - return x + y - - tools = await mcp._list_tools_mcp() - assert len(tools) == 1 - tool = tools[0] - assert tool.description == "Add two numbers" - - async def test_tool_decorator_instance_method(self): - mcp = FastMCP() - - class MyClass: - def __init__(self, x: int): - self.x = x - - def add(self, y: int) -> int: - return self.x + y - - obj = MyClass(10) - mcp.add_tool(Tool.from_function(obj.add)) - async with Client(mcp) as client: - result = await client.call_tool("add", {"y": 2}) - assert result.data == 12 - - async def test_tool_decorator_classmethod(self): - mcp = FastMCP() - - class MyClass: - x: int = 10 - - @classmethod - def add(cls, y: int) -> int: - return cls.x + y - - mcp.add_tool(Tool.from_function(MyClass.add)) - async with Client(mcp) as client: - result = await client.call_tool("add", {"y": 2}) - assert result.data == 12 - - async def test_tool_decorator_staticmethod(self): - mcp = FastMCP() - - class MyClass: - @mcp.tool - @staticmethod - def add(x: int, y: int) -> int: - return x + y - - async with Client(mcp) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result.data == 3 - - async def test_tool_decorator_async_function(self): - mcp = FastMCP() - - @mcp.tool - async def add(x: int, y: int) -> int: - return x + y - - async with Client(mcp) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result.data == 3 + # Verify the tool is in the local provider + tool = await mcp._local_provider.get_tool("my_tool") + assert tool is not None + assert tool.name == "my_tool" - async def test_tool_decorator_classmethod_error(self): + async def test_resource_decorator_delegates_to_local_provider(self): + """Test that @mcp.resource registers with the local provider.""" mcp = FastMCP() - with pytest.raises(ValueError, match="To decorate a classmethod"): + @mcp.resource("resource://test") + def my_resource() -> str: + return "content" - class MyClass: - @mcp.tool - @classmethod - def add(cls, y: int) -> None: - pass + # Verify the resource is in the local provider + resource = await mcp._local_provider.get_resource("resource://test") + assert resource is not None - async def test_tool_decorator_classmethod_async_function(self): + async def test_prompt_decorator_delegates_to_local_provider(self): + """Test that @mcp.prompt registers with the local provider.""" mcp = FastMCP() - class MyClass: - x = 10 - - @classmethod - async def add(cls, y: int) -> int: - return cls.x + y - - mcp.add_tool(Tool.from_function(MyClass.add)) - async with Client(mcp) as client: - result = await client.call_tool("add", {"y": 2}) - assert result.data == 12 - - async def test_tool_decorator_staticmethod_async_function(self): - mcp = FastMCP() - - class MyClass: - @staticmethod - async def add(x: int, y: int) -> int: - return x + y - - mcp.add_tool(Tool.from_function(MyClass.add)) - async with Client(mcp) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result.data == 3 - - async def test_tool_decorator_staticmethod_order(self): - """Test that the recommended decorator order works for static methods""" - mcp = FastMCP() - - class MyClass: - @mcp.tool - @staticmethod - def add_v1(x: int, y: int) -> int: - return x + y - - # Test that the recommended order works - async with Client(mcp) as client: - result = await client.call_tool("add_v1", {"x": 1, "y": 2}) - assert result.data == 3 - - async def test_tool_decorator_with_tags(self): - """Test that the tool decorator properly sets tags.""" - mcp = FastMCP() - - @mcp.tool(tags={"example", "test-tag"}) - def sample_tool(x: int) -> int: - return x * 2 + @mcp.prompt + def my_prompt() -> str: + return "prompt content" - # Verify the tags were set correctly (local inventory) - tools_dict = await mcp._tool_manager.get_tools() - assert len(tools_dict) == 1 - only_tool = next(iter(tools_dict.values())) - assert only_tool.tags == {"example", "test-tag"} + # Verify the prompt is in the local provider + prompt = await mcp._local_provider.get_prompt("my_prompt") + assert prompt is not None + assert prompt.name == "my_prompt" - async def test_add_tool_with_custom_name(self): - """Test adding a tool with a custom name using server.add_tool().""" + async def test_add_tool_delegates_to_local_provider(self): + """Test that mcp.add_tool() registers with the local provider.""" mcp = FastMCP() - def multiply(a: int, b: int) -> int: - """Multiply two numbers.""" - return a * b - - mcp.add_tool(Tool.from_function(multiply, name="custom_multiply")) - - # Check that the tool is registered with the custom name - tools = await mcp.get_tools() - assert "custom_multiply" in tools - - # Call the tool by its custom name - async with Client(mcp) as client: - result = await client.call_tool("custom_multiply", {"a": 5, "b": 3}) - assert result.data == 15 + def standalone_tool() -> str: + return "result" - # Original name should not be registered - assert "multiply" not in tools + mcp.add_tool(FunctionTool.from_function(standalone_tool)) - async def test_tool_with_annotated_arguments(self): - """Test that tools with annotated arguments work correctly.""" - mcp = FastMCP() + # Verify the tool is in the local provider + tool = await mcp._local_provider.get_tool("standalone_tool") + assert tool is not None + assert tool.name == "standalone_tool" - @mcp.tool - def add( - x: Annotated[int, Field(description="x is an int")], - y: Annotated[str, Field(description="y is not an int")], - ) -> None: - pass - - tool = (await mcp.get_tools())["add"] - assert tool.parameters["properties"]["x"]["description"] == "x is an int" - assert tool.parameters["properties"]["y"]["description"] == "y is not an int" - - async def test_tool_with_field_defaults(self): - """Test that tools with annotated arguments work correctly.""" + async def test_get_tools_includes_local_provider_tools(self): + """Test that get_tools() returns tools from local provider.""" mcp = FastMCP() @mcp.tool - def add( - x: int = Field(description="x is an int"), - y: str = Field(description="y is not an int"), - ) -> None: - pass - - tool = (await mcp.get_tools())["add"] - assert tool.parameters["properties"]["x"]["description"] == "x is an int" - assert tool.parameters["properties"]["y"]["description"] == "y is not an int" - - async def test_tool_direct_function_call(self): - """Test that tools can be registered via direct function call.""" - mcp = FastMCP() - - def standalone_function(x: int, y: int) -> int: - """A standalone function to be registered.""" - return x + y - - # Register it directly using the new syntax - result_fn = mcp.tool(standalone_function, name="direct_call_tool") + def local_tool() -> str: + return "local" - # The function should be returned unchanged - assert isinstance(result_fn, FunctionTool) - - # Verify the tool was registered correctly tools = await mcp.get_tools() - assert tools["direct_call_tool"] is result_fn - - # Verify it can be called - async with Client(mcp) as client: - result = await client.call_tool("direct_call_tool", {"x": 5, "y": 3}) - assert result.data == 8 - - async def test_tool_decorator_with_string_name(self): - """Test that @tool("custom_name") syntax works correctly.""" - mcp = FastMCP() - - @mcp.tool("string_named_tool") - def my_function(x: int) -> str: - """A function with a string name.""" - return f"Result: {x}" - - # Verify the tool was registered with the custom name - tools = await mcp.get_tools() - assert "string_named_tool" in tools - assert "my_function" not in tools # Original name should not be registered - - # Verify it can be called - async with Client(mcp) as client: - result = await client.call_tool("string_named_tool", {"x": 42}) - assert result.data == "Result: 42" - - async def test_tool_decorator_conflicting_names_error(self): - """Test that providing both positional and keyword name raises an error.""" - mcp = FastMCP() - - with pytest.raises( - TypeError, - match="Cannot specify both a name as first argument and as keyword argument", - ): - - @mcp.tool("positional_name", name="keyword_name") - def my_function(x: int) -> str: - return f"Result: {x}" - - async def test_tool_decorator_with_output_schema(self): - mcp = FastMCP() - - with pytest.raises( - ValueError, match="Output schemas must represent object types" - ): - - @mcp.tool(output_schema={"type": "integer"}) - def my_function(x: int) -> str: - return f"Result: {x}" - - async def test_tool_decorator_with_meta(self): - """Test that meta parameter is passed through the tool decorator.""" - mcp = FastMCP() - - meta_data = {"version": "1.0", "author": "test"} - - @mcp.tool(meta=meta_data) - def multiply(a: int, b: int) -> int: - """Multiply two numbers.""" - return a * b - - tools_dict = await mcp.get_tools() - tool = tools_dict["multiply"] - - assert tool.meta == meta_data - - -class TestResourceDecorator: - async def test_no_resources_before_decorator(self): - mcp = FastMCP() - - with pytest.raises(McpError, match="Unknown resource"): - async with Client(mcp) as client: - await client.read_resource("resource://data") - - async def test_resource_decorator(self): - mcp = FastMCP() - - @mcp.resource("resource://data") - def get_data() -> str: - return "Hello, world!" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Hello, world!" - - async def test_resource_decorator_incorrect_usage(self): - mcp = FastMCP() - - with pytest.raises( - TypeError, match="The @resource decorator was used incorrectly" - ): - - @mcp.resource # Missing parentheses #type: ignore - def get_data() -> str: - return "Hello, world!" - - async def test_resource_decorator_with_name(self): - mcp = FastMCP() - - @mcp.resource("resource://data", name="custom-data") - def get_data() -> str: - return "Hello, world!" - - resources_dict = await mcp.get_resources() - resources = list(resources_dict.values()) - assert len(resources) == 1 - assert resources[0].name == "custom-data" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Hello, world!" - - async def test_resource_decorator_with_description(self): - mcp = FastMCP() - - @mcp.resource("resource://data", description="Data resource") - def get_data() -> str: - return "Hello, world!" - - resources_dict = await mcp.get_resources() - resources = list(resources_dict.values()) - assert len(resources) == 1 - assert resources[0].description == "Data resource" - - async def test_resource_decorator_with_tags(self): - """Test that the resource decorator properly sets tags.""" - mcp = FastMCP() - - @mcp.resource("resource://data", tags={"example", "test-tag"}) - def get_data() -> str: - return "Hello, world!" - - resources_dict = await mcp.get_resources() - resources = list(resources_dict.values()) - assert len(resources) == 1 - assert resources[0].tags == {"example", "test-tag"} - - async def test_resource_decorator_instance_method(self): - mcp = FastMCP() - - class MyClass: - def __init__(self, prefix: str): - self.prefix = prefix - - def get_data(self) -> str: - return f"{self.prefix} Hello, world!" - - obj = MyClass("My prefix:") - - mcp.add_resource( - Resource.from_function( - obj.get_data, uri="resource://data", name="instance-resource" - ) - ) - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "My prefix: Hello, world!" - - async def test_resource_decorator_classmethod(self): - mcp = FastMCP() - - class MyClass: - prefix = "Class prefix:" - - @classmethod - def get_data(cls) -> str: - return f"{cls.prefix} Hello, world!" - - mcp.add_resource( - Resource.from_function( - MyClass.get_data, uri="resource://data", name="class-resource" - ) - ) - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Class prefix: Hello, world!" - - async def test_resource_decorator_classmethod_error(self): - mcp = FastMCP() - - with pytest.raises(ValueError, match="To decorate a classmethod"): - - class MyClass: - @mcp.resource("resource://data") - @classmethod - def get_data(cls) -> None: - pass - - async def test_resource_decorator_staticmethod(self): - mcp = FastMCP() - - class MyClass: - @mcp.resource("resource://data") - @staticmethod - def get_data() -> str: - return "Static Hello, world!" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Static Hello, world!" - - async def test_resource_decorator_async_function(self): - mcp = FastMCP() - - @mcp.resource("resource://data") - async def get_data() -> str: - return "Async Hello, world!" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Async Hello, world!" - - async def test_resource_decorator_staticmethod_order(self): - """Test that both decorator orders work for static methods""" - mcp = FastMCP() - - class MyClass: - @mcp.resource("resource://data") # type: ignore[misc] # Type checker warns but runtime works - @staticmethod - def get_data() -> str: - return "Static Hello, world!" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Static Hello, world!" - - async def test_resource_decorator_with_meta(self): - """Test that meta parameter is passed through the resource decorator.""" - mcp = FastMCP() - - meta_data = {"version": "1.0", "author": "test"} - - @mcp.resource("resource://data", meta=meta_data) - def get_data() -> str: - return "Hello, world!" - - resources_dict = await mcp.get_resources() - resource = resources_dict["resource://data"] - - assert resource.meta == meta_data - - async def test_resource_content_with_meta_in_response(self): - """Test that ResourceContent meta is passed through to MCP response.""" - mcp = FastMCP() - - @mcp.resource("resource://widget") - def get_widget() -> ResourceContent: - return ResourceContent( - content="content", - mime_type="text/html", - meta={"csp": "script-src 'self'", "version": "1.0"}, - ) - - async with Client(mcp) as client: - result = await client.read_resource("resource://widget") - assert len(result) == 1 - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "content" - assert result[0].mimeType == "text/html" - # Meta should be in the response - assert isinstance(result[0], TextResourceContents) - assert result[0].meta == {"csp": "script-src 'self'", "version": "1.0"} - - async def test_resource_content_binary_with_meta(self): - """Test that ResourceContent with binary content and meta works.""" - mcp = FastMCP() - - @mcp.resource("resource://binary") - def get_binary() -> ResourceContent: - return ResourceContent( - content=b"\x00\x01\x02", - meta={"encoding": "raw"}, - ) - - async with Client(mcp) as client: - result = await client.read_resource("resource://binary") - assert len(result) == 1 - # Binary content comes back as blob - assert hasattr(result[0], "blob") - assert isinstance(result[0], BlobResourceContents) - assert result[0].meta == {"encoding": "raw"} - - async def test_resource_content_without_meta(self): - """Test that ResourceContent without meta works (meta is None).""" - mcp = FastMCP() - - @mcp.resource("resource://plain") - def get_plain() -> ResourceContent: - return ResourceContent(content="plain content") - - async with Client(mcp) as client: - result = await client.read_resource("resource://plain") - assert len(result) == 1 - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "plain content" - # Meta should be None - assert isinstance(result[0], TextResourceContents) - assert result[0].meta is None - - -class TestTemplateDecorator: - async def test_template_decorator(self): - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - def get_data(name: str) -> str: - return f"Data for {name}" - - templates_dict = await mcp.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].name == "get_data" - assert templates[0].uri_template == "resource://{name}/data" - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Data for test" - - async def test_template_decorator_incorrect_usage(self): - mcp = FastMCP() - - with pytest.raises( - TypeError, match="The @resource decorator was used incorrectly" - ): - - @mcp.resource # Missing parentheses #type: ignore - def get_data(name: str) -> str: - return f"Data for {name}" - - async def test_template_decorator_with_name(self): - mcp = FastMCP() - - @mcp.resource("resource://{name}/data", name="custom-template") - def get_data(name: str) -> str: - return f"Data for {name}" - - templates_dict = await mcp.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].name == "custom-template" - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Data for test" - - async def test_template_decorator_with_description(self): - mcp = FastMCP() - - @mcp.resource("resource://{name}/data", description="Template description") - def get_data(name: str) -> str: - return f"Data for {name}" - - templates_dict = await mcp.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].description == "Template description" - - async def test_template_decorator_instance_method(self): - mcp = FastMCP() - - class MyClass: - def __init__(self, prefix: str): - self.prefix = prefix - - def get_data(self, name: str) -> str: - return f"{self.prefix} Data for {name}" - - obj = MyClass("My prefix:") - template = ResourceTemplate.from_function( - obj.get_data, - uri_template="resource://{name}/data", - name="instance-template", - ) - mcp.add_template(template) - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "My prefix: Data for test" - - async def test_template_decorator_classmethod(self): - mcp = FastMCP() - - class MyClass: - prefix = "Class prefix:" - - @classmethod - def get_data(cls, name: str) -> str: - return f"{cls.prefix} Data for {name}" - - template = ResourceTemplate.from_function( - MyClass.get_data, - uri_template="resource://{name}/data", - name="class-template", - ) - mcp.add_template(template) - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Class prefix: Data for test" - - async def test_template_decorator_staticmethod(self): - mcp = FastMCP() - - class MyClass: - @mcp.resource("resource://{name}/data") - @staticmethod - def get_data(name: str) -> str: - return f"Static Data for {name}" - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Static Data for test" - - async def test_template_decorator_async_function(self): - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - async def get_data(name: str) -> str: - return f"Async Data for {name}" - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Async Data for test" - - async def test_template_decorator_with_tags(self): - """Test that the template decorator properly sets tags.""" - mcp = FastMCP() - - @mcp.resource("resource://{param}", tags={"template", "test-tag"}) - def template_resource(param: str) -> str: - return f"Template resource: {param}" - - templates_dict = await mcp.get_resource_templates() - template = templates_dict["resource://{param}"] - assert template.tags == {"template", "test-tag"} - - async def test_template_decorator_wildcard_param(self): - mcp = FastMCP() - - @mcp.resource("resource://{param*}") - def template_resource(param: str) -> str: - return f"Template resource: {param}" - - templates_dict = await mcp.get_resource_templates() - template = templates_dict["resource://{param*}"] - assert template.uri_template == "resource://{param*}" - assert template.name == "template_resource" - - async def test_template_decorator_with_meta(self): - """Test that meta parameter is passed through the template decorator.""" - mcp = FastMCP() - - meta_data = {"version": "2.0", "template": "test"} - - @mcp.resource("resource://{param}/data", meta=meta_data) - def get_template_data(param: str) -> str: - return f"Data for {param}" - - templates_dict = await mcp.get_resource_templates() - template = templates_dict["resource://{param}/data"] - - assert template.meta == meta_data - - -class TestPromptDecorator: - async def test_prompt_decorator(self): - mcp = FastMCP() - - @mcp.prompt - def fn() -> str: - return "Hello, world!" - - prompts_dict = await mcp.get_prompts() - assert len(prompts_dict) == 1 - prompt = prompts_dict["fn"] - assert prompt.name == "fn" - # Don't compare functions directly since validate_call wraps them - content = await prompt.render() - if not isinstance(content, PromptResult): - content = PromptResult.from_value(content) - assert isinstance(content.messages[0].content, TextContent) - assert content.messages[0].content.text == "Hello, world!" - - async def test_prompt_decorator_without_parentheses(self): - mcp = FastMCP() - - # This should now work correctly (not raise an error) - @mcp.prompt # No parentheses - this is now supported - def fn() -> str: - return "Hello, world!" - - # Verify the prompt was registered correctly - prompts = await mcp.get_prompts() - assert "fn" in prompts - - # Verify it can be called - async with Client(mcp) as client: - result = await client.get_prompt("fn") - assert len(result.messages) == 1 - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "Hello, world!" - - async def test_prompt_decorator_with_name(self): - mcp = FastMCP() - - @mcp.prompt(name="custom_name") - def fn() -> str: - return "Hello, world!" - - prompts_dict = await mcp.get_prompts() - assert len(prompts_dict) == 1 - prompt = prompts_dict["custom_name"] - assert prompt.name == "custom_name" - content = await prompt.render() - if not isinstance(content, PromptResult): - content = PromptResult.from_value(content) - assert isinstance(content.messages[0].content, TextContent) - assert content.messages[0].content.text == "Hello, world!" - - async def test_prompt_decorator_with_description(self): - mcp = FastMCP() - - @mcp.prompt(description="A custom description") - def fn() -> str: - return "Hello, world!" - - prompts_dict = await mcp.get_prompts() - assert len(prompts_dict) == 1 - prompt = prompts_dict["fn"] - assert prompt.description == "A custom description" - content = await prompt.render() - if not isinstance(content, PromptResult): - content = PromptResult.from_value(content) - assert isinstance(content.messages[0].content, TextContent) - assert content.messages[0].content.text == "Hello, world!" - - async def test_prompt_decorator_with_parameters(self): - mcp = FastMCP() - - @mcp.prompt - def test_prompt(name: str, greeting: str = "Hello") -> str: - return f"{greeting}, {name}!" - - prompts_dict = await mcp.get_prompts() - assert len(prompts_dict) == 1 - prompt = prompts_dict["test_prompt"] - assert prompt.arguments is not None - assert len(prompt.arguments) == 2 - assert prompt.arguments[0].name == "name" - assert prompt.arguments[0].required is True - assert prompt.arguments[1].name == "greeting" - assert prompt.arguments[1].required is False - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt", {"name": "World"}) - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Hello, World!" - - result = await client.get_prompt( - "test_prompt", {"name": "World", "greeting": "Hi"} - ) - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Hi, World!" - - async def test_prompt_decorator_instance_method(self): - mcp = FastMCP() - - class MyClass: - def __init__(self, prefix: str): - self.prefix = prefix - - def test_prompt(self) -> str: - return f"{self.prefix} Hello, world!" - - obj = MyClass("My prefix:") - mcp.add_prompt(Prompt.from_function(obj.test_prompt, name="test_prompt")) - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt") - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "My prefix: Hello, world!" - - async def test_prompt_decorator_classmethod(self): - mcp = FastMCP() - - class MyClass: - prefix = "Class prefix:" - - @classmethod - def test_prompt(cls) -> str: - return f"{cls.prefix} Hello, world!" - - mcp.add_prompt(Prompt.from_function(MyClass.test_prompt, name="test_prompt")) - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt") - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Class prefix: Hello, world!" - - async def test_prompt_decorator_classmethod_error(self): - mcp = FastMCP() - - with pytest.raises(ValueError, match="To decorate a classmethod"): - - class MyClass: - @mcp.prompt - @classmethod - def test_prompt(cls) -> None: - pass - - async def test_prompt_decorator_staticmethod(self): - mcp = FastMCP() - - class MyClass: - @mcp.prompt - @staticmethod - def test_prompt() -> str: - return "Static Hello, world!" - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt") - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Static Hello, world!" - - async def test_prompt_decorator_async_function(self): - mcp = FastMCP() - - @mcp.prompt - async def test_prompt() -> str: - return "Async Hello, world!" - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt") - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Async Hello, world!" - - async def test_prompt_decorator_with_tags(self): - """Test that the prompt decorator properly sets tags.""" - mcp = FastMCP() - - @mcp.prompt(tags={"example", "test-tag"}) - def sample_prompt() -> str: - return "Hello, world!" - - prompts_dict = await mcp.get_prompts() - assert len(prompts_dict) == 1 - prompt = prompts_dict["sample_prompt"] - assert prompt.tags == {"example", "test-tag"} - - async def test_prompt_decorator_with_string_name(self): - """Test that @prompt(\"custom_name\") syntax works correctly.""" - mcp = FastMCP() - - @mcp.prompt("string_named_prompt") - def my_function() -> str: - """A function with a string name.""" - return "Hello from string named prompt!" - - # Verify the prompt was registered with the custom name - prompts = await mcp.get_prompts() - assert "string_named_prompt" in prompts - assert "my_function" not in prompts # Original name should not be registered - - # Verify it can be called - async with Client(mcp) as client: - result = await client.get_prompt("string_named_prompt") - assert len(result.messages) == 1 - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "Hello from string named prompt!" - - async def test_prompt_direct_function_call(self): - """Test that prompts can be registered via direct function call.""" - mcp = FastMCP() - - def standalone_function() -> str: - """A standalone function to be registered.""" - return "Hello from direct call!" - - # Register it directly using the new syntax - result_fn = mcp.prompt(standalone_function, name="direct_call_prompt") - - # The function should be returned unchanged - assert isinstance(result_fn, FunctionPrompt) - - # Verify the prompt was registered correctly - prompts = await mcp.get_prompts() - assert prompts["direct_call_prompt"] is result_fn - - # Verify it can be called - async with Client(mcp) as client: - result = await client.get_prompt("direct_call_prompt") - assert len(result.messages) == 1 - assert isinstance(result.messages[0].content, TextContent) - assert result.messages[0].content.text == "Hello from direct call!" - - async def test_prompt_decorator_conflicting_names_error(self): - """Test that providing both positional and keyword names raises an error.""" - mcp = FastMCP() - - with pytest.raises( - TypeError, - match="Cannot specify both a name as first argument and as keyword argument", - ): - - @mcp.prompt("positional_name", name="keyword_name") - def my_function() -> str: - return "Hello, world!" - - async def test_prompt_decorator_staticmethod_order(self): - """Test that both decorator orders work for static methods""" - mcp = FastMCP() - - class MyClass: - @mcp.prompt # type: ignore[misc] # Type checker warns but runtime works - @staticmethod - def test_prompt() -> str: - return "Static Hello, world!" - - async with Client(mcp) as client: - result = await client.get_prompt("test_prompt") - assert len(result.messages) == 1 - message = result.messages[0] - assert isinstance(message.content, TextContent) - assert message.content.text == "Static Hello, world!" - - async def test_prompt_decorator_with_meta(self): - """Test that meta parameter is passed through the prompt decorator.""" - mcp = FastMCP() - - meta_data = {"version": "3.0", "type": "prompt"} - - @mcp.prompt(meta=meta_data) - def test_prompt(message: str) -> str: - return f"Response: {message}" - - prompts_dict = await mcp.get_prompts() - prompt = prompts_dict["test_prompt"] - - assert prompt.meta == meta_data + assert "local_tool" in tools class TestResourcePrefixMounting: diff --git a/tests/server/test_server_interactions.py b/tests/server/test_server_interactions.py index 9ce4c60ff2..7a9a5b18f2 100644 --- a/tests/server/test_server_interactions.py +++ b/tests/server/test_server_interactions.py @@ -1,36 +1,25 @@ import base64 -import datetime -import functools -import json -import uuid from dataclasses import dataclass -from enum import Enum from pathlib import Path -from typing import Annotated, Any, Literal import pytest -from inline_snapshot import snapshot from mcp import McpError from mcp.types import ( - AudioContent, BlobResourceContents, EmbeddedResource, ImageContent, TextContent, TextResourceContents, ) -from pydantic import AnyUrl, BaseModel, Field, TypeAdapter +from pydantic import AnyUrl, BaseModel from typing_extensions import TypedDict -from fastmcp import Client, Context, FastMCP -from fastmcp.client.client import CallToolResult +from fastmcp import Client, FastMCP from fastmcp.client.transports import FastMCPTransport from fastmcp.exceptions import ToolError -from fastmcp.prompts.prompt import Prompt, PromptMessage, PromptResult -from fastmcp.resources import FileResource, ResourceTemplate +from fastmcp.prompts.prompt import PromptMessage, PromptResult +from fastmcp.resources import FileResource from fastmcp.resources.resource import FunctionResource -from fastmcp.tools.tool import Tool, ToolResult -from fastmcp.utilities.json_schema import compress_schema from fastmcp.utilities.tests import temporary_settings from fastmcp.utilities.types import Audio, File, Image @@ -278,1078 +267,6 @@ async def test_call_excluded_tool(self): assert result_2.data == 2 -class TestToolReturnTypes: - async def test_string(self): - mcp = FastMCP() - - @mcp.tool - def string_tool() -> str: - return "Hello, world!" - - async with Client(mcp) as client: - result = await client.call_tool("string_tool", {}) - assert result.data == "Hello, world!" - - async def test_bytes(self, tmp_path: Path): - mcp = FastMCP() - - @mcp.tool - def bytes_tool() -> bytes: - return b"Hello, world!" - - async with Client(mcp) as client: - result = await client.call_tool("bytes_tool", {}) - assert result.data == "Hello, world!" - - async def test_uuid(self): - mcp = FastMCP() - - test_uuid = uuid.uuid4() - - @mcp.tool - def uuid_tool() -> uuid.UUID: - return test_uuid - - async with Client(mcp) as client: - result = await client.call_tool("uuid_tool", {}) - assert result.data == str(test_uuid) - - async def test_path(self): - mcp = FastMCP() - - test_path = Path("/tmp/test.txt") - - @mcp.tool - def path_tool() -> Path: - return test_path - - async with Client(mcp) as client: - result = await client.call_tool("path_tool", {}) - assert result.data == str(test_path) - - async def test_datetime(self): - mcp = FastMCP() - - dt = datetime.datetime(2025, 4, 25, 1, 2, 3) - - @mcp.tool - def datetime_tool() -> datetime.datetime: - return dt - - async with Client(mcp) as client: - result = await client.call_tool("datetime_tool", {}) - assert result.data == dt - - async def test_image(self, tmp_path: Path): - mcp = FastMCP() - - @mcp.tool - def image_tool(path: str) -> Image: - return Image(path) - - # Create a test image - image_path = tmp_path / "test.png" - image_path.write_bytes(b"fake png data") - - async with Client(mcp) as client: - result = await client.call_tool("image_tool", {"path": str(image_path)}) - assert result.structured_content is None - content = result.content[0] - assert isinstance(content, ImageContent) - assert content.type == "image" - assert content.mimeType == "image/png" - # Verify base64 encoding - decoded = base64.b64decode(content.data) - assert decoded == b"fake png data" - - async def test_audio(self, tmp_path: Path): - mcp = FastMCP() - - @mcp.tool - def audio_tool(path: str) -> Audio: - return Audio(path) - - # Create a test audio file - audio_path = tmp_path / "test.wav" - audio_path.write_bytes(b"fake wav data") - - async with Client(mcp) as client: - result = await client.call_tool("audio_tool", {"path": str(audio_path)}) - content = result.content[0] - assert isinstance(content, AudioContent) - assert content.type == "audio" - assert content.mimeType == "audio/wav" - # Verify base64 encoding - decoded = base64.b64decode(content.data) - assert decoded == b"fake wav data" - - async def test_file(self, tmp_path: Path): - mcp = FastMCP() - - @mcp.tool - def file_tool(path: str) -> File: - return File(path) - - # Create a test file - file_path = tmp_path / "test.bin" - file_path.write_bytes(b"test file data") - - async with Client(mcp) as client: - result = await client.call_tool("file_tool", {"path": str(file_path)}) - content = result.content[0] - assert isinstance(content, EmbeddedResource) - assert content.type == "resource" - resource = content.resource - assert resource.mimeType == "application/octet-stream" - # Verify base64 encoding - assert hasattr(resource, "blob") - blob_data = getattr(resource, "blob") - decoded = base64.b64decode(blob_data) - assert decoded == b"test file data" - # Verify URI points to the file - assert str(resource.uri) == file_path.resolve().as_uri() - - async def test_tool_mixed_content(self, tool_server: FastMCP): - async with Client(tool_server) as client: - result = await client.call_tool("mixed_content_tool", {}) - assert len(result.content) == 3 - content1 = result.content[0] - content2 = result.content[1] - content3 = result.content[2] - assert isinstance(content1, TextContent) - assert content1.text == "Hello" - assert isinstance(content2, ImageContent) - assert content2.mimeType == "application/octet-stream" - assert content2.data == "abc" - assert isinstance(content3, EmbeddedResource) - assert content3.type == "resource" - resource = content3.resource - assert resource.mimeType == "application/octet-stream" - assert hasattr(resource, "blob") - blob_data = getattr(resource, "blob") - decoded = base64.b64decode(blob_data) - assert decoded == b"abc" - - async def test_tool_mixed_list_with_image( - self, tool_server: FastMCP, tmp_path: Path - ): - """Test that lists containing Image objects and other types are handled - correctly. Items now preserve their original order.""" - # Create a test image - image_path = tmp_path / "test.png" - image_path.write_bytes(b"test image data") - - async with Client(tool_server) as client: - result = await client.call_tool( - "mixed_list_fn", {"image_path": str(image_path)} - ) - assert len(result.content) == 4 # Now each item is separate - # Check text message (first item) - content1 = result.content[0] - assert isinstance(content1, TextContent) - assert content1.text == "text message" - # Check image conversion (second item) - content2 = result.content[1] - assert isinstance(content2, ImageContent) - assert content2.mimeType == "image/png" - assert base64.b64decode(content2.data) == b"test image data" - # Check dict content (third item) - content3 = result.content[2] - assert isinstance(content3, TextContent) - assert json.loads(content3.text) == {"key": "value"} - # Check direct TextContent (fourth item) - content4 = result.content[3] - assert isinstance(content4, TextContent) - assert content4.text == "direct content" - - async def test_tool_mixed_list_with_audio( - self, tool_server: FastMCP, tmp_path: Path - ): - """Test that lists containing Audio objects and other types are handled - correctly. Items now preserve their original order.""" - # Create a test audio file - audio_path = tmp_path / "test.wav" - audio_path.write_bytes(b"test audio data") - - async with Client(tool_server) as client: - result = await client.call_tool( - "mixed_audio_list_fn", {"audio_path": str(audio_path)} - ) - assert len(result.content) == 4 # Now each item is separate - # Check text message (first item) - content1 = result.content[0] - assert isinstance(content1, TextContent) - assert content1.text == "text message" - # Check audio conversion (second item) - content2 = result.content[1] - assert isinstance(content2, AudioContent) - assert content2.mimeType == "audio/wav" - assert base64.b64decode(content2.data) == b"test audio data" - # Check dict content (third item) - content3 = result.content[2] - assert isinstance(content3, TextContent) - assert json.loads(content3.text) == {"key": "value"} - # Check direct TextContent (fourth item) - content4 = result.content[3] - assert isinstance(content4, TextContent) - assert content4.text == "direct content" - - async def test_tool_mixed_list_with_file( - self, tool_server: FastMCP, tmp_path: Path - ): - """Test that lists containing File objects and other types are handled - correctly. Items now preserve their original order.""" - # Create a test file - file_path = tmp_path / "test.bin" - file_path.write_bytes(b"test file data") - - async with Client(tool_server) as client: - result = await client.call_tool( - "mixed_file_list_fn", {"file_path": str(file_path)} - ) - assert len(result.content) == 4 # Now each item is separate - # Check text message (first item) - content1 = result.content[0] - assert isinstance(content1, TextContent) - assert content1.text == "text message" - # Check file conversion (second item) - content2 = result.content[1] - assert isinstance(content2, EmbeddedResource) - assert content2.type == "resource" - resource = content2.resource - assert resource.mimeType == "application/octet-stream" - assert hasattr(resource, "blob") - blob_data = getattr(resource, "blob") - assert base64.b64decode(blob_data) == b"test file data" - # Check dict content (third item) - content3 = result.content[2] - assert isinstance(content3, TextContent) - assert json.loads(content3.text) == {"key": "value"} - # Check direct TextContent (fourth item) - content4 = result.content[3] - assert isinstance(content4, TextContent) - assert content4.text == "direct content" - - -class TestToolParameters: - async def test_parameter_descriptions_with_field_annotations(self): - mcp = FastMCP("Test Server") - - @mcp.tool - def greet( - name: Annotated[str, Field(description="The name to greet")], - title: Annotated[str, Field(description="Optional title", default="")], - ) -> str: - """A greeting tool""" - return f"Hello {title} {name}" - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - tool = tools[0] - - # Check that parameter descriptions are present in the schema - properties = tool.inputSchema["properties"] - assert "name" in properties - assert properties["name"]["description"] == "The name to greet" - assert "title" in properties - assert properties["title"]["description"] == "Optional title" - assert properties["title"]["default"] == "" - assert tool.inputSchema["required"] == ["name"] - - async def test_parameter_descriptions_with_field_defaults(self): - mcp = FastMCP("Test Server") - - @mcp.tool - def greet( - name: str = Field(description="The name to greet"), - title: str = Field(description="Optional title", default=""), - ) -> str: - """A greeting tool""" - return f"Hello {title} {name}" - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - tool = tools[0] - - # Check that parameter descriptions are present in the schema - properties = tool.inputSchema["properties"] - assert "name" in properties - assert properties["name"]["description"] == "The name to greet" - assert "title" in properties - assert properties["title"]["description"] == "Optional title" - assert properties["title"]["default"] == "" - assert tool.inputSchema["required"] == ["name"] - - async def test_tool_with_bytes_input(self): - mcp = FastMCP() - - @mcp.tool - def process_image(image: bytes) -> Image: - return Image(data=image) - - async with Client(mcp) as client: - result = await client.call_tool( - "process_image", {"image": b"fake png data"} - ) - assert result.structured_content is None - assert isinstance(result.content[0], ImageContent) - assert result.content[0].mimeType == "image/png" - assert result.content[0].data == base64.b64encode(b"fake png data").decode() - - async def test_tool_with_invalid_input(self): - mcp = FastMCP() - - @mcp.tool - def my_tool(x: int) -> int: - return x + 1 - - async with Client(mcp) as client: - with pytest.raises( - ToolError, - match="Input should be a valid integer", - ): - await client.call_tool("my_tool", {"x": "not an int"}) - - async def test_tool_int_coercion(self): - """Test that string ints are coerced by default.""" - mcp = FastMCP() - - @mcp.tool - def add_one(x: int) -> int: - return x + 1 - - async with Client(mcp) as client: - # String input should be coerced with default settings - result = await client.call_tool("add_one", {"x": "42"}) - assert result.data == 43 - - async def test_tool_bool_coercion(self): - """Test that string bools are coerced by default.""" - mcp = FastMCP() - - @mcp.tool - def toggle(flag: bool) -> bool: - return not flag - - async with Client(mcp) as client: - # String input should be coerced with default settings - result = await client.call_tool("toggle", {"flag": "true"}) - assert result.data is False - - result = await client.call_tool("toggle", {"flag": "false"}) - assert result.data is True - - async def test_annotated_field_validation(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: Annotated[int, Field(ge=1)]) -> None: - pass - - async with Client(mcp) as client: - with pytest.raises( - ToolError, - match="Input should be greater than or equal to 1", - ): - await client.call_tool("analyze", {"x": 0}) - - async def test_default_field_validation(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: int = Field(ge=1)) -> None: - pass - - async with Client(mcp) as client: - with pytest.raises( - ToolError, - match="Input should be greater than or equal to 1", - ): - await client.call_tool("analyze", {"x": 0}) - - async def test_default_field_is_still_required_if_no_default_specified(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: int = Field()) -> None: - pass - - async with Client(mcp) as client: - with pytest.raises(ToolError, match="Missing required argument"): - await client.call_tool("analyze", {}) - - async def test_literal_type_validation_error(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: Literal["a", "b"]) -> None: - pass - - async with Client(mcp) as client: - with pytest.raises( - ToolError, - match="Input should be 'a' or 'b'", - ): - await client.call_tool("analyze", {"x": "c"}) - - async def test_literal_type_validation_success(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: Literal["a", "b"]) -> str: - return x - - async with Client(mcp) as client: - result = await client.call_tool("analyze", {"x": "a"}) - assert result.data == "a" - - async def test_enum_type_validation_error(self): - mcp = FastMCP() - - class MyEnum(Enum): - RED = "red" - GREEN = "green" - BLUE = "blue" - - @mcp.tool - def analyze(x: MyEnum) -> str: - return x.value - - async with Client(mcp) as client: - with pytest.raises( - ToolError, - match="Input should be 'red', 'green' or 'blue'", - ): - await client.call_tool("analyze", {"x": "some-color"}) - - async def test_enum_type_validation_success(self): - mcp = FastMCP() - - class MyEnum(Enum): - RED = "red" - GREEN = "green" - BLUE = "blue" - - @mcp.tool - def analyze(x: MyEnum) -> str: - return x.value - - async with Client(mcp) as client: - result = await client.call_tool("analyze", {"x": "red"}) - assert result.data == "red" - - async def test_union_type_validation(self): - mcp = FastMCP() - - @mcp.tool - def analyze(x: int | float) -> str: - return str(x) - - async with Client(mcp) as client: - result = await client.call_tool("analyze", {"x": 1}) - assert result.data == "1" - - result = await client.call_tool("analyze", {"x": 1.0}) - assert result.data == "1.0" - - with pytest.raises( - ToolError, - match="Input should be a valid", - ): - await client.call_tool("analyze", {"x": "not a number"}) - - async def test_path_type(self): - mcp = FastMCP() - - @mcp.tool - def send_path(path: Path) -> str: - assert isinstance(path, Path) - return str(path) - - # Use a platform-independent path - test_path = Path("tmp") / "test.txt" - - async with Client(mcp) as client: - result = await client.call_tool("send_path", {"path": str(test_path)}) - assert result.data == str(test_path) - - async def test_path_type_error(self): - mcp = FastMCP() - - @mcp.tool - def send_path(path: Path) -> str: - return str(path) - - async with Client(mcp) as client: - with pytest.raises(ToolError, match="Input is not a valid path"): - await client.call_tool("send_path", {"path": 1}) - - async def test_uuid_type(self): - mcp = FastMCP() - - @mcp.tool - def send_uuid(x: uuid.UUID) -> str: - assert isinstance(x, uuid.UUID) - return str(x) - - test_uuid = uuid.uuid4() - - async with Client(mcp) as client: - result = await client.call_tool("send_uuid", {"x": test_uuid}) - assert result.data == str(test_uuid) - - async def test_uuid_type_error(self): - mcp = FastMCP() - - @mcp.tool - def send_uuid(x: uuid.UUID) -> str: - return str(x) - - async with Client(mcp) as client: - with pytest.raises(ToolError, match="Input should be a valid UUID"): - await client.call_tool("send_uuid", {"x": "not a uuid"}) - - async def test_datetime_type(self): - mcp = FastMCP() - - @mcp.tool - def send_datetime(x: datetime.datetime) -> str: - return x.isoformat() - - dt = datetime.datetime(2025, 4, 25, 1, 2, 3) - - async with Client(mcp) as client: - result = await client.call_tool("send_datetime", {"x": dt}) - assert result.data == dt.isoformat() - - async def test_datetime_type_parse_string(self): - mcp = FastMCP() - - @mcp.tool - def send_datetime(x: datetime.datetime) -> str: - return x.isoformat() - - async with Client(mcp) as client: - result = await client.call_tool( - "send_datetime", {"x": "2021-01-01T00:00:00"} - ) - assert result.data == "2021-01-01T00:00:00" - - async def test_datetime_type_error(self): - mcp = FastMCP() - - @mcp.tool - def send_datetime(x: datetime.datetime) -> str: - return x.isoformat() - - async with Client(mcp) as client: - with pytest.raises(ToolError, match="Input should be a valid datetime"): - await client.call_tool("send_datetime", {"x": "not a datetime"}) - - async def test_date_type(self): - mcp = FastMCP() - - @mcp.tool - def send_date(x: datetime.date) -> str: - return x.isoformat() - - async with Client(mcp) as client: - result = await client.call_tool("send_date", {"x": datetime.date.today()}) - assert result.data == datetime.date.today().isoformat() - - async def test_date_type_parse_string(self): - mcp = FastMCP() - - @mcp.tool - def send_date(x: datetime.date) -> str: - return x.isoformat() - - async with Client(mcp) as client: - result = await client.call_tool("send_date", {"x": "2021-01-01"}) - assert result.data == "2021-01-01" - - async def test_timedelta_type(self): - mcp = FastMCP() - - @mcp.tool - def send_timedelta(x: datetime.timedelta) -> str: - return str(x) - - async with Client(mcp) as client: - result = await client.call_tool( - "send_timedelta", {"x": datetime.timedelta(days=1)} - ) - assert result.data == "1 day, 0:00:00" - - async def test_timedelta_type_parse_int(self): - """Test that int input is coerced to timedelta (seconds).""" - mcp = FastMCP() - - @mcp.tool - def send_timedelta(x: datetime.timedelta) -> str: - return str(x) - - async with Client(mcp) as client: - # Int input should be coerced to timedelta (seconds) - result = await client.call_tool("send_timedelta", {"x": 1000}) - assert ( - "0:16:40" in result.data or "16:40" in result.data - ) # 1000 seconds = 16 minutes 40 seconds - - async def test_annotated_string_description(self): - mcp = FastMCP() - - @mcp.tool - def f(x: Annotated[int, "A number"]): - return x - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - assert tools[0].inputSchema["properties"]["x"]["description"] == "A number" - - -class TestToolOutputSchema: - @pytest.mark.parametrize("annotation", [str, int, float, bool, list, AnyUrl]) - async def test_simple_output_schema(self, annotation): - mcp = FastMCP() - - @mcp.tool - def f() -> annotation: - return "hello" - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - - type_schema = TypeAdapter(annotation).json_schema() - # Remove title fields from the schema for comparison (title pruning is enabled) - type_schema = compress_schema(type_schema, prune_titles=True) - # this line will fail until MCP adds output schemas!! - assert tools[0].outputSchema == { - "type": "object", - "properties": {"result": type_schema}, - "required": ["result"], - "x-fastmcp-wrap-result": True, - } - - @pytest.mark.parametrize( - "annotation", - [dict[str, int | str], PersonTypedDict, PersonModel, PersonDataclass], - ) - async def test_structured_output_schema(self, annotation): - mcp = FastMCP() - - @mcp.tool - def f() -> annotation: - return {"name": "John", "age": 30} - - async with Client(mcp) as client: - tools = await client.list_tools() - - type_schema = compress_schema( - TypeAdapter(annotation).json_schema(), prune_titles=True - ) - assert len(tools) == 1 - - # Normalize anyOf ordering for comparison since union type order - # can vary between environments when using annotation resolution - actual_schema = _normalize_anyof_order(tools[0].outputSchema) - expected_schema = _normalize_anyof_order(type_schema) - assert actual_schema == expected_schema - - async def test_disabled_output_schema_no_structured_content(self): - mcp = FastMCP() - - @mcp.tool(output_schema=None) - def f() -> int: - return 42 - - async with Client(mcp) as client: - result = await client.call_tool("f", {}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "42" - assert result.structured_content is None - assert result.data is None - - async def test_manual_structured_content(self): - mcp = FastMCP() - - @mcp.tool - def f() -> ToolResult: - return ToolResult( - content="Hello, world!", structured_content={"message": "Hello, world!"} - ) - - assert f.output_schema is None - - async with Client(mcp) as client: - result = await client.call_tool("f", {}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Hello, world!" - assert result.structured_content == {"message": "Hello, world!"} - assert result.data == {"message": "Hello, world!"} - - async def test_output_schema_none_full_handshake(self): - """Test that output_schema=None works through full client/server - handshake. We test this by returning a scalar, which requires an output - schema to serialize.""" - mcp = FastMCP() - - @mcp.tool(output_schema=None) - def simple_tool() -> int: - return 42 - - async with Client(mcp) as client: - # List tools and verify output schema is None - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "simple_tool") - assert tool.outputSchema is None - - # Call tool and verify no structured content - result = await client.call_tool("simple_tool", {}) - assert result.structured_content is None - assert result.data is None - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "42" - - async def test_output_schema_explicit_object_full_handshake(self): - """Test explicit object output schema through full client/server handshake.""" - mcp = FastMCP() - - @mcp.tool( - output_schema={ - "type": "object", - "properties": { - "greeting": {"type": "string"}, - "count": {"type": "integer"}, - }, - "required": ["greeting"], - } - ) - def explicit_tool() -> dict[str, Any]: - return {"greeting": "Hello", "count": 42} - - async with Client(mcp) as client: - # List tools and verify exact schema is preserved - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "explicit_tool") - expected_schema = { - "type": "object", - "properties": { - "greeting": {"type": "string"}, - "count": {"type": "integer"}, - }, - "required": ["greeting"], - } - assert tool.outputSchema == expected_schema - - # Call tool and verify structured content matches return value directly - result = await client.call_tool("explicit_tool", {}) - assert result.structured_content == {"greeting": "Hello", "count": 42} - # Client deserializes according to schema, so check fields - # result.data is a dynamically generated Root type, so check attributes directly - assert result.data is not None - assert result.data.greeting == "Hello" # type: ignore[attr-defined] - assert result.data.count == 42 # type: ignore[attr-defined] - - async def test_output_schema_wrapped_primitive_full_handshake(self): - """Test wrapped primitive output schema through full client/server handshake.""" - mcp = FastMCP() - - @mcp.tool - def primitive_tool() -> str: - return "Hello, primitives!" - - async with Client(mcp) as client: - # List tools and verify schema shows wrapped structure - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "primitive_tool") - expected_schema = { - "type": "object", - "properties": {"result": {"type": "string"}}, - "required": ["result"], - "x-fastmcp-wrap-result": True, - } - assert tool.outputSchema == expected_schema - - # Call tool and verify structured content is wrapped - result = await client.call_tool("primitive_tool", {}) - assert result.structured_content == {"result": "Hello, primitives!"} - assert result.data == "Hello, primitives!" # Client unwraps for convenience - - async def test_output_schema_complex_type_full_handshake(self): - """Test complex type output schema through full client/server handshake.""" - mcp = FastMCP() - - @mcp.tool - def complex_tool() -> list[dict[str, int]]: - return [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - - async with Client(mcp) as client: - # List tools and verify schema shows wrapped array - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "complex_tool") - expected_inner_schema = compress_schema( - TypeAdapter(list[dict[str, int]]).json_schema(), prune_titles=True - ) - expected_schema = { - "type": "object", - "properties": {"result": expected_inner_schema}, - "required": ["result"], - "x-fastmcp-wrap-result": True, - } - assert tool.outputSchema == expected_schema - - # Call tool and verify structured content is wrapped - result = await client.call_tool("complex_tool", {}) - expected_data = [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - assert result.structured_content == {"result": expected_data} - # Client deserializes - just verify we got data back - assert result.data is not None - - async def test_output_schema_dataclass_full_handshake(self): - """Test dataclass output schema through full client/server handshake.""" - mcp = FastMCP() - - @dataclass - class User: - name: str - age: int - - @mcp.tool - def dataclass_tool() -> User: - return User(name="Alice", age=30) - - async with Client(mcp) as client: - # List tools and verify schema is object type (not wrapped) - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "dataclass_tool") - expected_schema = compress_schema( - TypeAdapter(User).json_schema(), prune_titles=True - ) - assert tool.outputSchema == expected_schema - assert ( - tool.outputSchema and "x-fastmcp-wrap-result" not in tool.outputSchema - ) - - # Call tool and verify structured content is direct - result = await client.call_tool("dataclass_tool", {}) - assert result.structured_content == {"name": "Alice", "age": 30} - # Client deserializes according to schema - # result.data is a dynamically generated Root type, so check attributes directly - assert result.data is not None - assert result.data.name == "Alice" # type: ignore[attr-defined] - assert result.data.age == 30 # type: ignore[attr-defined] - - async def test_output_schema_mixed_content_types(self): - """Test tools with mixed content and output schemas.""" - mcp = FastMCP() - - @mcp.tool - def mixed_output() -> list[Any]: - # Return mixed content that includes MCP types and regular data - return [ - "text message", - {"structured": "data"}, - TextContent(type="text", text="direct MCP content"), - ] - - async with Client(mcp) as client: - result = await client.call_tool("mixed_output", {}) - - # Should have multiple content blocks - assert result == snapshot( - CallToolResult( - content=[ - TextContent(type="text", text="text message"), - TextContent(type="text", text='{"structured":"data"}'), - TextContent(type="text", text="direct MCP content"), - ], - structured_content={ - "result": [ - "text message", - {"structured": "data"}, - { - "type": "text", - "text": "direct MCP content", - "annotations": None, - "_meta": None, - }, - ] - }, - data=[ - "text message", - {"structured": "data"}, - { - "type": "text", - "text": "direct MCP content", - "annotations": None, - "_meta": None, - }, - ], - meta=None, - ) - ) - - async def test_output_schema_serialization_edge_cases(self): - """Test edge cases in output schema serialization.""" - mcp = FastMCP() - - @mcp.tool - def edge_case_tool() -> tuple[int, str]: - return (42, "hello") - - async with Client(mcp) as client: - # Verify tuple gets proper schema - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "edge_case_tool") - - # Tuples should be wrapped since they're not object type - assert tool.outputSchema and "x-fastmcp-wrap-result" in tool.outputSchema - - result = await client.call_tool("edge_case_tool", {}) - # Should be wrapped with result key - assert result.structured_content == {"result": [42, "hello"]} - assert result.data == [42, "hello"] - - -class TestToolContextInjection: - """Test context injection in tools.""" - - async def test_context_detection(self): - """Test that context parameters are properly detected.""" - mcp = FastMCP() - - @mcp.tool - def tool_with_context(x: int, ctx: Context) -> str: - return f"Request {ctx.request_id}: {x}" - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - assert tools[0].name == "tool_with_context" - - async def test_context_injection(self): - """Test that context is properly injected into tool calls.""" - mcp = FastMCP() - - @mcp.tool - def tool_with_context(x: int, ctx: Context) -> str: - assert isinstance(ctx, Context) - assert ctx.request_id is not None - return ctx.request_id - - async with Client(mcp) as client: - result = await client.call_tool("tool_with_context", {"x": 42}) - assert result.data == "1" - - async def test_async_context(self): - """Test that context works in async functions.""" - mcp = FastMCP() - - @mcp.tool - async def async_tool(x: int, ctx: Context) -> str: - assert ctx.request_id is not None - return f"Async request {ctx.request_id}: {x}" - - async with Client(mcp) as client: - result = await client.call_tool("async_tool", {"x": 42}) - assert result.data == "Async request 1: 42" - - async def test_optional_context(self): - """Test that context is optional.""" - mcp = FastMCP() - - @mcp.tool - def no_context(x: int) -> int: - return x * 2 - - async with Client(mcp) as client: - result = await client.call_tool("no_context", {"x": 21}) - assert result.data == 42 - - async def test_context_resource_access(self): - """Test that context can access resources.""" - mcp = FastMCP() - - @mcp.resource("test://data") - def test_resource() -> str: - return "resource data" - - @mcp.tool - async def tool_with_resource(ctx: Context) -> str: - r_iter = await ctx.read_resource("test://data") - r_list = list(r_iter) - assert len(r_list) == 1 - r = r_list[0] - return f"Read resource: {r.content} with mime type {r.mime_type}" - - async with Client(mcp) as client: - result = await client.call_tool("tool_with_resource", {}) - assert ( - result.data == "Read resource: resource data with mime type text/plain" - ) - - async def test_tool_decorator_with_tags(self): - """Test that the tool decorator properly sets tags.""" - mcp = FastMCP() - - @mcp.tool(tags={"example", "test-tag"}) - def sample_tool(x: int) -> int: - return x * 2 - - # Verify the tool exists - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools) == 1 - # Note: MCPTool from the client API doesn't expose tags - - async def test_callable_object_with_context(self): - """Test that a callable object can be used as a tool with context.""" - mcp = FastMCP() - - class MyTool: - async def __call__(self, x: int, ctx: Context) -> int: - return x + int(ctx.request_id) - - mcp.add_tool(Tool.from_function(MyTool(), name="MyTool")) - - async with Client(mcp) as client: - result = await client.call_tool("MyTool", {"x": 2}) - assert result.data == 3 - - async def test_decorated_tool_with_functools_wraps(self): - """Regression test for #2524: @mcp.tool with functools.wraps decorator.""" - - def custom_decorator(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - return wrapper - - mcp = FastMCP() - - @mcp.tool - @custom_decorator - async def decorated_tool(ctx: Context, query: str) -> str: - assert isinstance(ctx, Context) - return f"query: {query}" - - async with Client(mcp) as client: - # Verify ctx is not in the schema - tools = await client.list_tools() - tool = next(t for t in tools if t.name == "decorated_tool") - assert "ctx" not in tool.inputSchema.get("properties", {}) - - # Verify the tool works - result = await client.call_tool("decorated_tool", {"query": "test"}) - assert result.data == "query: test" - - class TestToolEnabled: async def test_toggle_enabled(self): mcp = FastMCP() @@ -1615,21 +532,6 @@ async def test_read_excluded_resource(self): await client.read_resource(AnyUrl("resource://1")) -class TestResourceContext: - async def test_resource_with_context_annotation_gets_context(self): - mcp = FastMCP() - - @mcp.resource("resource://test") - def resource_with_context(ctx: Context) -> str: - assert isinstance(ctx, Context) - return ctx.request_id - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "1" - - class TestResourceEnabled: async def test_toggle_enabled(self): mcp = FastMCP() @@ -1726,289 +628,6 @@ def sample_resource() -> str: await client.read_resource(AnyUrl("resource://data")) -class TestResourceTemplates: - async def test_resource_with_params_not_in_uri(self): - """Test that a resource with function parameters raises an error if the URI - parameters don't match""" - mcp = FastMCP() - - with pytest.raises( - ValueError, - match="URI template must contain at least one parameter", - ): - - @mcp.resource("resource://data") - def get_data_fn(param: str) -> str: - return f"Data: {param}" - - async def test_resource_with_uri_params_without_args(self): - """Test that a resource with URI parameters is automatically a template""" - mcp = FastMCP() - - with pytest.raises( - ValueError, - match="URI parameters .* must be a subset of the function arguments", - ): - - @mcp.resource("resource://{param}") - def get_data() -> str: - return "Data" - - async def test_resource_with_untyped_params(self): - """Test that a resource with untyped parameters raises an error""" - mcp = FastMCP() - - @mcp.resource("resource://{param}") - def get_data(param) -> str: - return "Data" - - async def test_resource_matching_params(self): - """Test that a resource with matching URI and function parameters works""" - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - def get_data(name: str) -> str: - return f"Data for {name}" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test/data")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Data for test" - - async def test_resource_mismatched_params(self): - """Test that mismatched parameters raise an error""" - mcp = FastMCP() - - with pytest.raises( - ValueError, - match="Required function arguments .* must be a subset of the URI path parameters", - ): - - @mcp.resource("resource://{name}/data") - def get_data(user: str) -> str: - return f"Data for {user}" - - async def test_resource_multiple_params(self): - """Test that multiple parameters work correctly""" - mcp = FastMCP() - - @mcp.resource("resource://{org}/{repo}/data") - def get_data(org: str, repo: str) -> str: - return f"Data for {org}/{repo}" - - async with Client(mcp) as client: - result = await client.read_resource( - AnyUrl("resource://cursor/fastmcp/data") - ) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Data for cursor/fastmcp" - - async def test_resource_multiple_mismatched_params(self): - """Test that mismatched parameters raise an error""" - mcp = FastMCP() - - with pytest.raises( - ValueError, - match="Required function arguments .* must be a subset of the URI path parameters", - ): - - @mcp.resource("resource://{org}/{repo}/data") - def get_data_mismatched(org: str, repo_2: str) -> str: - return f"Data for {org}" - - """Test that a resource with no parameters works as a regular resource""" - mcp = FastMCP() - - @mcp.resource("resource://static") - def get_static_data() -> str: - return "Static data" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://static")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Static data" - - async def test_template_with_varkwargs(self): - """Test that a template can have **kwargs.""" - mcp = FastMCP() - - @mcp.resource("test://{x}/{y}/{z}") - def func(**kwargs: int) -> int: - return sum(kwargs.values()) - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("test://1/2/3")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "6" - - async def test_template_with_default_params(self): - """Test that a template can have default parameters.""" - mcp = FastMCP() - - @mcp.resource("math://add/{x}") - def add(x: int, y: int = 10) -> int: - return x + y - - # Verify it's registered as a template - templates_dict = await mcp.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].uri_template == "math://add/{x}" - - # Call the template and verify it uses the default value - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("math://add/5")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "15" - - # Can also call with explicit params - result2 = await client.read_resource(AnyUrl("math://add/7")) - assert isinstance(result2[0], TextResourceContents) - assert result2[0].text == "17" - - async def test_template_to_resource_conversion(self): - """Test that a template can be converted to a resource.""" - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - def get_data(name: str) -> str: - return f"Data for {name}" - - # Verify it's registered as a template - templates_dict = await mcp.get_resource_templates() - templates = list(templates_dict.values()) - assert len(templates) == 1 - assert templates[0].uri_template == "resource://{name}/data" - - # When accessed, should create a concrete resource - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test/data")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Data for test" - - async def test_template_decorator_with_tags(self): - mcp = FastMCP() - - @mcp.resource("resource://{param}", tags={"template", "test-tag"}) - def template_resource(param: str) -> str: - return f"Template resource: {param}" - - templates_dict = await mcp.get_resource_templates() - template = templates_dict["resource://{param}"] - assert template.tags == {"template", "test-tag"} - - async def test_template_decorator_wildcard_param(self): - mcp = FastMCP() - - @mcp.resource("resource://{param*}") - def template_resource(param: str) -> str: - return f"Template resource: {param}" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test/data")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Template resource: test/data" - - async def test_template_with_query_params(self): - """Test RFC 6570 query parameters in resource templates.""" - mcp = FastMCP() - - @mcp.resource("data://{id}{?format,limit}") - def get_data(id: str, format: str = "json", limit: int = 10) -> str: - return f"id={id}, format={format}, limit={limit}" - - async with Client(mcp) as client: - # No query params - uses defaults - result = await client.read_resource(AnyUrl("data://123")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "id=123, format=json, limit=10" - - # One query param - result = await client.read_resource(AnyUrl("data://123?format=xml")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "id=123, format=xml, limit=10" - - # Multiple query params - result = await client.read_resource( - AnyUrl("data://123?format=csv&limit=50") - ) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "id=123, format=csv, limit=50" - - async def test_templates_match_in_order_of_definition(self): - """ - If a wildcard template is defined first, it will take priority over another - matching template. - - """ - mcp = FastMCP() - - @mcp.resource("resource://{param*}") - def template_resource(param: str) -> str: - return f"Template resource 1: {param}" - - @mcp.resource("resource://{x}/{y}") - def template_resource_with_params(x: str, y: str) -> str: - return f"Template resource 2: {x}/{y}" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://a/b/c")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Template resource 1: a/b/c" - - result = await client.read_resource(AnyUrl("resource://a/b")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Template resource 1: a/b" - - async def test_templates_shadow_each_other_reorder(self): - """ - If a wildcard template is defined second, it will *not* take priority over - another matching template. - """ - mcp = FastMCP() - - @mcp.resource("resource://{x}/{y}") - def template_resource_with_params(x: str, y: str) -> str: - return f"Template resource 1: {x}/{y}" - - @mcp.resource("resource://{param*}") - def template_resource(param: str) -> str: - return f"Template resource 2: {param}" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://a/b/c")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Template resource 2: a/b/c" - - result = await client.read_resource(AnyUrl("resource://a/b")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text == "Template resource 1: a/b" - - async def test_resource_template_with_annotations(self): - """Test that resource template annotations are visible to clients.""" - mcp = FastMCP() - - @mcp.resource( - "api://users/{user_id}", - annotations={"httpMethod": "GET", "Cache-Control": "no-cache"}, - ) - def get_user(user_id: str) -> str: - return f"User {user_id} data" - - async with Client(mcp) as client: - templates = await client.list_resource_templates() - assert len(templates) == 1 - - template = templates[0] - assert template.uriTemplate == "api://users/{user_id}" - - assert template.annotations is not None - assert hasattr(template.annotations, "httpMethod") - assert getattr(template.annotations, "httpMethod") == "GET" - assert hasattr(template.annotations, "Cache-Control") - assert getattr(template.annotations, "Cache-Control") == "no-cache" - - class TestResourceTemplatesTags: def create_server(self, include_tags=None, exclude_tags=None): mcp = FastMCP(include_tags=include_tags, exclude_tags=exclude_tags) @@ -2084,38 +703,6 @@ async def test_read_resource_template_excludes_tags(self): assert result[0].text == "Template resource 2: x" -class TestResourceTemplateContext: - async def test_resource_template_context(self): - mcp = FastMCP() - - @mcp.resource("resource://{param}") - def resource_template(param: str, ctx: Context) -> str: - assert isinstance(ctx, Context) - return f"Resource template: {param} {ctx.request_id}" - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text.startswith("Resource template: test 1") - - async def test_resource_template_context_with_callable_object(self): - mcp = FastMCP() - - class MyResource: - def __call__(self, param: str, ctx: Context) -> str: - return f"Resource template: {param} {ctx.request_id}" - - template = ResourceTemplate.from_function( - MyResource(), uri_template="resource://{param}" - ) - mcp.add_template(template) - - async with Client(mcp) as client: - result = await client.read_resource(AnyUrl("resource://test")) - assert isinstance(result[0], TextResourceContents) - assert result[0].text.startswith("Resource template: test 1") - - class TestResourceTemplateEnabled: async def test_toggle_enabled(self): mcp = FastMCP() @@ -2559,39 +1146,6 @@ def sample_prompt() -> str: await client.get_prompt("sample_prompt") -class TestPromptContext: - async def test_prompt_context(self): - mcp = FastMCP() - - @mcp.prompt - def prompt_fn(name: str, ctx: Context) -> str: - assert isinstance(ctx, Context) - return f"Hello, {name}! {ctx.request_id}" - - async with Client(mcp) as client: - result = await client.get_prompt("prompt_fn", {"name": "World"}) - assert len(result.messages) == 1 - message = result.messages[0] - assert message.role == "user" - - async def test_prompt_context_with_callable_object(self): - mcp = FastMCP() - - class MyPrompt: - def __call__(self, name: str, ctx: Context) -> str: - return f"Hello, {name}! {ctx.request_id}" - - mcp.add_prompt(Prompt.from_function(MyPrompt(), name="my_prompt")) # noqa: F821 - - async with Client(mcp) as client: - result = await client.get_prompt("my_prompt", {"name": "World"}) - assert len(result.messages) == 1 - message = result.messages[0] - assert message.role == "user" - assert isinstance(message.content, TextContent) - assert message.content.text == "Hello, World! 1" - - class TestPromptTags: def create_server(self, include_tags=None, exclude_tags=None): mcp = FastMCP(include_tags=include_tags, exclude_tags=exclude_tags) diff --git a/tests/server/test_tool_annotations.py b/tests/server/test_tool_annotations.py index 4a9983c72e..742dc7478c 100644 --- a/tests/server/test_tool_annotations.py +++ b/tests/server/test_tool_annotations.py @@ -23,7 +23,7 @@ def echo(message: str) -> str: return message # Check internal tool objects directly - tools_dict = await mcp._tool_manager.get_tools() + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert tools[0].annotations is not None @@ -126,7 +126,7 @@ def modify(data: dict[str, Any]) -> dict[str, Any]: return {"modified": True, **data} # Check internal tool objects directly - tools_dict = await mcp._tool_manager.get_tools() + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert tools[0].annotations is not None @@ -186,7 +186,7 @@ def create_item(name: str, value: int) -> dict[str, Any]: mcp.add_tool(tool) # Check internal tool objects directly - tools_dict = await mcp._tool_manager.get_tools() + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert tools[0].annotations is not None diff --git a/tests/server/test_tool_transformation.py b/tests/server/test_tool_transformation.py index a35289fc1b..63e6900504 100644 --- a/tests/server/test_tool_transformation.py +++ b/tests/server/test_tool_transformation.py @@ -13,7 +13,7 @@ def echo(message: str) -> str: mcp.add_tool_transformation("echo", ToolTransformConfig(name="echo_transformed")) - tools_dict = await mcp._tool_manager.get_tools() + tools_dict = await mcp.get_tools() tools = list(tools_dict.values()) assert len(tools) == 1 assert "echo_transformed" in tools_dict diff --git a/tests/tools/test_tool_manager.py b/tests/tools/test_tool_manager.py deleted file mode 100644 index 18637af99a..0000000000 --- a/tests/tools/test_tool_manager.py +++ /dev/null @@ -1,1078 +0,0 @@ -import functools -import json -import logging -import uuid -from typing import Annotated, Any - -import pydantic_core -import pytest -from inline_snapshot import snapshot -from mcp.types import ImageContent, TextContent -from pydantic import BaseModel, ValidationError - -from fastmcp import Context, FastMCP -from fastmcp.exceptions import NotFoundError, ToolError -from fastmcp.tools import FunctionTool, ToolManager -from fastmcp.tools.tool import Tool -from fastmcp.tools.tool_transform import ArgTransformConfig, ToolTransformConfig -from fastmcp.utilities.tests import caplog_for_fastmcp, temporary_settings -from fastmcp.utilities.types import Image -from tests.conftest import get_fn_name - - -class TestAddTools: - async def test_basic_function(self): - """Test registering and running a basic function.""" - - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - manager = ToolManager() - tool = Tool.from_function(add) - manager.add_tool(tool) - - tool = await manager.get_tool("add") - assert tool is not None - assert tool.name == "add" - assert tool.description == "Add two numbers." - assert tool.parameters["properties"]["a"]["type"] == "integer" - assert tool.parameters["properties"]["b"]["type"] == "integer" - - async def test_async_function(self): - """Test registering and running an async function.""" - - async def fetch_data(url: str) -> str: - """Fetch data from URL.""" - return f"Data from {url}" - - manager = ToolManager() - tool = Tool.from_function(fetch_data) - manager.add_tool(tool) - - tool = await manager.get_tool("fetch_data") - assert tool is not None - assert tool.name == "fetch_data" - assert tool.description == "Fetch data from URL." - assert tool.parameters["properties"]["url"]["type"] == "string" - - async def test_pydantic_model_function(self): - """Test registering a function that takes a Pydantic model.""" - - class UserInput(BaseModel): - name: str - age: int - - def create_user(user: UserInput, flag: bool) -> dict: - """Create a new user.""" - return {"id": 1, **user.model_dump()} - - manager = ToolManager() - tool = Tool.from_function(create_user) - manager.add_tool(tool) - - tool = await manager.get_tool("create_user") - assert tool is not None - assert tool.name == "create_user" - assert tool.description == "Create a new user." - assert "name" in tool.parameters["$defs"]["UserInput"]["properties"] - assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] - assert "flag" in tool.parameters["properties"] - - async def test_callable_object(self): - class Adder: - """Adds two numbers.""" - - def __call__(self, x: int, y: int) -> int: - """ignore this""" - return x + y - - manager = ToolManager() - tool = Tool.from_function(Adder()) - manager.add_tool(tool) - - tool = await manager.get_tool("Adder") - assert tool is not None - assert tool.name == "Adder" - assert tool.description == "Adds two numbers." - assert len(tool.parameters["properties"]) == 2 - assert tool.parameters["properties"]["x"]["type"] == "integer" - assert tool.parameters["properties"]["y"]["type"] == "integer" - - async def test_async_callable_object(self): - class Adder: - """Adds two numbers.""" - - async def __call__(self, x: int, y: int) -> int: - """ignore this""" - return x + y - - manager = ToolManager() - tool = Tool.from_function(Adder()) - manager.add_tool(tool) - - tool = await manager.get_tool("Adder") - assert tool is not None - assert tool.name == "Adder" - assert tool.description == "Adds two numbers." - assert len(tool.parameters["properties"]) == 2 - assert tool.parameters["properties"]["x"]["type"] == "integer" - assert tool.parameters["properties"]["y"]["type"] == "integer" - - async def test_tool_with_image_return(self): - def image_tool(data: bytes) -> Image: - return Image(data=data) - - manager = ToolManager() - tool = Tool.from_function(image_tool) - manager.add_tool(tool) - - tool = await manager.get_tool("image_tool") - result = await tool.run({"data": "test.png"}) - assert tool.parameters["properties"]["data"]["type"] == "string" - assert isinstance(result.content[0], ImageContent) - assert result.structured_content is None - - def test_add_noncallable_tool(self): - manager = ToolManager() - with pytest.raises(TypeError, match="not a callable object"): - assert isinstance(1, int) # Intentionally passing invalid type - # Intentionally passing invalid type to test error handling - tool = Tool.from_function(1) # type: ignore[arg-type] - manager.add_tool(tool) - - def test_add_lambda(self): - manager = ToolManager() - tool = Tool.from_function(lambda x: x, name="my_tool") - manager.add_tool(tool) - assert tool.name == "my_tool" - - def test_add_lambda_with_no_name(self): - manager = ToolManager() - with pytest.raises( - ValueError, match="You must provide a name for lambda functions" - ): - tool = Tool.from_function(lambda x: x) - manager.add_tool(tool) - - async def test_remove_tool_successfully(self): - """Test removing an added tool by key.""" - manager = ToolManager() - - def add(a: int, b: int) -> int: - return a + b - - tool = Tool.from_function(add) - manager.add_tool(tool) - assert await manager.get_tool("add") is not None - - manager.remove_tool("add") - with pytest.raises(NotFoundError): - await manager.get_tool("add") - - def test_remove_tool_missing_key(self): - """Test removing a tool that does not exist raises NotFoundError.""" - manager = ToolManager() - with pytest.raises(NotFoundError, match="Tool 'missing' not found"): - manager.remove_tool("missing") - - async def test_warn_on_duplicate_tools(self, caplog): - """Test warning on duplicate tools.""" - manager = ToolManager(duplicate_behavior="warn") - - def test_fn(x: int) -> int: - return x - - tool1 = Tool.from_function(test_fn, name="test_tool") - manager.add_tool(tool1) - - with caplog_for_fastmcp(caplog): - tool2 = Tool.from_function(test_fn, name="test_tool") - manager.add_tool(tool2) - - assert "Tool already exists: test_tool" in caplog.text - # Should have the tool - assert await manager.get_tool("test_tool") is not None - - def test_disable_warn_on_duplicate_tools(self, caplog): - """Test disabling warning on duplicate tools.""" - - def f(x: int) -> int: - return x - - manager = ToolManager(duplicate_behavior="ignore") - tool1 = Tool.from_function(f) - manager.add_tool(tool1) - with caplog.at_level(logging.WARNING): - tool2 = Tool.from_function(f) - manager.add_tool(tool2) - assert "Tool already exists: f" not in caplog.text - - def test_error_on_duplicate_tools(self): - """Test error on duplicate tools.""" - manager = ToolManager(duplicate_behavior="error") - - def test_fn(x: int) -> int: - return x - - tool1 = Tool.from_function(test_fn, name="test_tool") - manager.add_tool(tool1) - - with pytest.raises(ValueError, match="Tool already exists: test_tool"): - tool2 = Tool.from_function(test_fn, name="test_tool") - manager.add_tool(tool2) - - async def test_replace_duplicate_tools(self): - """Test replacing duplicate tools.""" - manager = ToolManager(duplicate_behavior="replace") - - def original_fn(x: int) -> int: - return x - - def replacement_fn(x: int) -> int: - return x + 1 - - tool1 = Tool.from_function(original_fn, name="test_tool") - manager.add_tool(tool1) - result = Tool.from_function(replacement_fn, name="test_tool") - manager.add_tool(result) - - # Should have replaced with the new tool - tool = await manager.get_tool("test_tool") - assert tool is not None - assert isinstance(tool, FunctionTool) - assert get_fn_name(tool.fn) == "replacement_fn" - - async def test_ignore_duplicate_tools(self): - """Test ignoring duplicate tools.""" - manager = ToolManager(duplicate_behavior="ignore") - - def original_fn(x: int) -> int: - return x - - def replacement_fn(x: int) -> int: - return x * 2 - - tool1 = Tool.from_function(original_fn, name="test_tool") - manager.add_tool(tool1) - result = Tool.from_function(replacement_fn, name="test_tool") - manager.add_tool(result) - - # Should keep the original - tool = await manager.get_tool("test_tool") - assert tool is not None - assert isinstance(tool, FunctionTool) - assert get_fn_name(tool.fn) == "original_fn" - # Result should be the original tool - assert isinstance(result, FunctionTool) - assert get_fn_name(result.fn) == "replacement_fn" - - -class TestListTools: - async def test_list_tools_with_transformed_names(self): - """Test listing tools with transformations.""" - - tool_manager = ToolManager() - - def add(a: int, b: int) -> int: - return a + b - - tool = Tool.from_function(add) - tool_manager.add_tool(tool) - - tool_manager.add_tool_transformation( - "add", ToolTransformConfig(name="add_transformed") - ) - tools_dict = await tool_manager.get_tools() - tools_by_name = {tool.name: tool for tool in tools_dict.values()} - assert "add_transformed" in tools_by_name - assert "add" not in tools_by_name - - async def test_list_tools_with_transforms(self): - """Test listing tools with transformations.""" - - tool_manager = ToolManager() - - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - tool = Tool.from_function(add) - tool_manager.add_tool(tool) - - tool_manager.add_tool_transformation( - "add", - ToolTransformConfig( - name="add_transformed", description=None, tags={"enabled_tools"} - ), - ) - tools_dict = await tool_manager.get_tools() - tools_by_name = {tool.name: tool for tool in tools_dict.values()} - assert "add_transformed" in tools_by_name - assert "add" not in tools_by_name - assert tools_by_name["add_transformed"].description is None - assert tools_by_name["add_transformed"].tags == {"enabled_tools"} - - -class TestToolTags: - """Test functionality related to tool tags.""" - - async def test_add_tool_with_tags(self): - """Test adding tags to a tool.""" - - def example_tool(x: int) -> int: - """An example tool with tags.""" - return x * 2 - - manager = ToolManager() - tool = Tool.from_function(example_tool, tags={"math", "utility"}) - manager.add_tool(tool) - - assert tool.tags == {"math", "utility"} - tool = await manager.get_tool("example_tool") - assert tool is not None - assert tool.tags == {"math", "utility"} - - async def test_add_tool_with_empty_tags(self): - """Test adding a tool with empty tags set.""" - - def example_tool(x: int) -> int: - """An example tool with empty tags.""" - return x * 2 - - manager = ToolManager() - tool = Tool.from_function(example_tool, tags=set()) - manager.add_tool(tool) - - assert tool.tags == set() - - async def test_add_tool_with_none_tags(self): - """Test adding a tool with None tags.""" - - def example_tool(x: int) -> int: - """An example tool with None tags.""" - return x * 2 - - manager = ToolManager() - tool = Tool.from_function(example_tool, tags=None) - manager.add_tool(tool) - - assert tool.tags == set() - - async def test_list_tools_with_tags(self): - """Test listing tools with specific tags.""" - - def math_tool(x: int) -> int: - """A math tool.""" - return x * 2 - - def string_tool(x: str) -> str: - """A string tool.""" - return x.upper() - - def mixed_tool(x: int) -> str: - """A tool with multiple tags.""" - return str(x) - - manager = ToolManager() - tool1 = Tool.from_function(math_tool, tags={"math"}) - manager.add_tool(tool1) - tool2 = Tool.from_function(string_tool, tags={"string", "utility"}) - manager.add_tool(tool2) - tool3 = Tool.from_function(mixed_tool, tags={"math", "utility", "string"}) - manager.add_tool(tool3) - - # Check if we can filter by tags when listing tools - math_tools = [ - tool for tool in (await manager.get_tools()).values() if "math" in tool.tags - ] - assert len(math_tools) == 2 - assert {tool.name for tool in math_tools} == {"math_tool", "mixed_tool"} - - utility_tools = [ - tool - for tool in (await manager.get_tools()).values() - if "utility" in tool.tags - ] - assert len(utility_tools) == 2 - assert {tool.name for tool in utility_tools} == {"string_tool", "mixed_tool"} - - -class TestCallTools: - async def test_call_tool(self): - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - manager = ToolManager() - tool = Tool.from_function(add) - manager.add_tool(tool) - result = await manager.call_tool("add", {"a": 1, "b": 2}) - - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "3" - assert result.structured_content == {"result": 3} - - async def test_call_async_tool(self): - async def double(n: int) -> int: - """Double a number.""" - return n * 2 - - manager = ToolManager() - tool = Tool.from_function(double) - manager.add_tool(tool) - result = await manager.call_tool("double", {"n": 5}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "10" - assert result.structured_content == {"result": 10} - - async def test_call_tool_callable_object(self): - class Adder: - """Adds two numbers.""" - - def __call__(self, x: int, y: int) -> int: - """ignore this""" - return x + y - - manager = ToolManager() - tool = Tool.from_function(Adder()) - manager.add_tool(tool) - result = await manager.call_tool("Adder", {"x": 1, "y": 2}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "3" - assert result.structured_content == {"result": 3} - - async def test_call_tool_callable_object_async(self): - class Adder: - """Adds two numbers.""" - - async def __call__(self, x: int, y: int) -> int: - """ignore this""" - return x + y - - manager = ToolManager() - tool = Tool.from_function(Adder()) - manager.add_tool(tool) - result = await manager.call_tool("Adder", {"x": 1, "y": 2}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "3" - assert result.structured_content == {"result": 3} - - async def test_call_tool_with_default_args(self): - def add(a: int, b: int = 1) -> int: - """Add two numbers.""" - return a + b - - manager = ToolManager() - tool = Tool.from_function(add) - manager.add_tool(tool) - result = await manager.call_tool("add", {"a": 1}) - - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "2" - assert result.structured_content == {"result": 2} - - async def test_call_tool_with_missing_args(self): - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - manager = ToolManager() - tool = Tool.from_function(add) - manager.add_tool(tool) - with pytest.raises(ValidationError): - await manager.call_tool("add", {"a": 1}) - - async def test_call_unknown_tool(self): - manager = ToolManager() - with pytest.raises(NotFoundError, match="Tool 'unknown' not found"): - await manager.call_tool("unknown", {"a": 1}) - - async def test_call_transformed_tool(self): - manager = ToolManager() - - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - tool = Tool.from_function(add) - manager.add_tool(tool) - - manager.add_tool_transformation( - "add", - ToolTransformConfig( - name="add_transformed", - description=None, - tags={"enabled_tools"}, - arguments={ - "a": ArgTransformConfig( - name="a_transformed", description=None, default=1 - ), - "b": ArgTransformConfig( - name="b_transformed", description=None, default=2 - ), - }, - ), - ) - - result = await manager.call_tool( - "add_transformed", {"a_transformed": 1, "b_transformed": 2} - ) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "3" - assert result.structured_content == {"result": 3} - - async def test_call_tool_with_list_int_input(self): - def sum_vals(vals: list[int]) -> int: - return sum(vals) - - manager = ToolManager() - tool = Tool.from_function(sum_vals) - manager.add_tool(tool) - - result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "6" - assert result.structured_content == {"result": 6} - - async def test_call_tool_with_list_str_or_str_input(self): - def concat_strs(vals: list[str] | str) -> str: - return vals if isinstance(vals, str) else "".join(vals) - - manager = ToolManager() - tool = Tool.from_function(concat_strs) - manager.add_tool(tool) - - # Try both with plain python object and with JSON list - result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "abc" - assert result.structured_content == {"result": "abc"} - - result = await manager.call_tool("concat_strs", {"vals": "a"}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "a" - assert result.structured_content == {"result": "a"} - - async def test_call_tool_with_complex_model(self): - class MyShrimpTank(BaseModel): - class Shrimp(BaseModel): - name: str - - shrimp: list[Shrimp] - x: None - - def name_shrimp(tank: MyShrimpTank, ctx: Context | None) -> list[str]: - return [x.name for x in tank.shrimp] - - manager = ToolManager() - tool = Tool.from_function(name_shrimp) - manager.add_tool(tool) - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await manager.call_tool( - "name_shrimp", - { - "tank": { - "x": None, - "shrimp": [{"name": "rex"}, {"name": "gertrude"}], - } - }, - ) - - # Adjacent non-MCP list items are combined into single content block - assert len(result.content) == 1 - assert result.content == snapshot( - [TextContent(type="text", text='["rex","gertrude"]')] - ) - assert result.structured_content == snapshot({"result": ["rex", "gertrude"]}) - - async def test_call_tool_with_custom_serializer(self): - """Test that a custom serializer provided to FastMCP is used by tools.""" - - def custom_serializer(data: Any) -> str: - if isinstance(data, dict): - return f"CUSTOM:{json.dumps(data)}" - return json.dumps(data) - - # Instantiate FastMCP with the custom serializer - mcp = FastMCP(tool_serializer=custom_serializer) - manager = mcp._tool_manager - - @mcp.tool - def get_data() -> dict: - return {"key": "value", "number": 123} - - result = await manager.call_tool("get_data", {}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'CUSTOM:{"key": "value", "number": 123}' - assert result.structured_content == {"key": "value", "number": 123} - - async def test_call_tool_with_list_result_custom_serializer(self): - """Test that a custom serializer provided to FastMCP is used by tools that return lists.""" - - def custom_serializer(data: Any) -> str: - if isinstance(data, list): - return f"CUSTOM:{json.dumps(data)}" - return json.dumps(data) - - mcp = FastMCP(tool_serializer=custom_serializer) - manager = mcp._tool_manager - - @mcp.tool - def get_data() -> list[dict]: - return [ - {"key": "value", "number": 123}, - {"key": "value2", "number": 456}, - ] - - result = await manager.call_tool("get_data", {}) - # Adjacent non-MCP list items get combined with custom serializer applied to each - assert len(result.content) == 1 - assert result.content == snapshot( - [ - TextContent( - type="text", - text='CUSTOM:[{"key": "value", "number": 123}, {"key": "value2", "number": 456}]', - ) - ] - ) - assert result.structured_content == snapshot( - { - "result": [ - {"key": "value", "number": 123}, - {"key": "value2", "number": 456}, - ] - } - ) - - async def test_custom_serializer_fallback_on_error(self): - """Test that a broken custom serializer gracefully falls back.""" - - uuid_result = uuid.uuid4() - - def custom_serializer(data: Any) -> str: - return json.dumps(data) - - mcp = FastMCP(tool_serializer=custom_serializer) - manager = mcp._tool_manager - - @mcp.tool - def get_data() -> uuid.UUID: - return uuid_result - - result = await manager.call_tool("get_data", {}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == pydantic_core.to_json(uuid_result).decode() - assert result.structured_content == {"result": str(uuid_result)} - - -class TestToolSchema: - async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context) -> int: - return a - - manager = ToolManager() - tool = Tool.from_function(something) - manager.add_tool(tool) - assert "ctx" not in json.dumps(tool.parameters) - assert "Context" not in json.dumps(tool.parameters) - - async def test_optional_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context | None) -> int: - return a - - manager = ToolManager() - tool = Tool.from_function(something) - manager.add_tool(tool) - assert "ctx" not in json.dumps(tool.parameters) - assert "Context" not in json.dumps(tool.parameters) - - async def test_annotated_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Annotated[Context | int | None, "ctx"]) -> int: - return a - - manager = ToolManager() - tool = Tool.from_function(something) - manager.add_tool(tool) - assert "ctx" not in json.dumps(tool.parameters) - assert "Context" not in json.dumps(tool.parameters) - - -class TestContextHandling: - """Test context handling in the tool manager.""" - - def test_context_parameter_detection(self): - """Test that context parameters are properly detected in - Tool.from_function().""" - - def tool_with_context(x: int, ctx: Context) -> str: - return str(x) - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - def tool_without_context(x: int) -> str: - return str(x) - - manager.add_tool(Tool.from_function(tool_without_context)) - - async def test_context_injection(self): - """Test that context is properly injected during tool execution.""" - - def tool_with_context(x: int, ctx: Context) -> str: - assert isinstance(ctx, Context) - return str(x) - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await manager.call_tool("tool_with_context", {"x": 42}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "42" - assert result.structured_content == {"result": "42"} - - async def test_context_injection_async(self): - """Test that context is properly injected in async tools.""" - - async def async_tool(x: int, ctx: Context) -> str: - assert isinstance(ctx, Context) - return str(x) - - manager = ToolManager() - tool = Tool.from_function(async_tool) - manager.add_tool(tool) - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await manager.call_tool("async_tool", {"x": 42}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "42" - assert result.structured_content == {"result": "42"} - - async def test_context_optional(self): - """Test that context is optional when calling tools.""" - - def tool_with_context(x: int, ctx: Context | None) -> int: - return x - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - # Should not raise an error when context is not provided - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await manager.call_tool("tool_with_context", {"x": 42}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "42" - assert result.structured_content == {"result": 42} - - def test_parameterized_context_parameter_detection(self): - """Test that context parameters are properly detected in - Tool.from_function().""" - - def tool_with_context(x: int, ctx: Context) -> str: - return str(x) - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - def test_annotated_context_parameter_detection(self): - def tool_with_context(x: int, ctx: Annotated[Context, "ctx"]) -> str: - return str(x) - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - def test_parameterized_union_context_parameter_detection(self): - """Test that context parameters are properly detected in - Tool.from_function().""" - - def tool_with_context(x: int, ctx: Context | None) -> str: - return str(x) - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - async def test_context_error_handling(self): - """Test error handling when context injection fails.""" - - def tool_with_context(x: int, ctx: Context) -> str: - raise ValueError("Test error") - - manager = ToolManager() - tool = Tool.from_function(tool_with_context) - manager.add_tool(tool) - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - with pytest.raises( - ToolError, match="Error calling tool 'tool_with_context'" - ): - await manager.call_tool("tool_with_context", {"x": 42}) - - async def test_context_with_functools_wraps_decorator(self): - """Regression test for #2524: decorated tools with Context should work.""" - - def custom_decorator(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - return wrapper - - @custom_decorator - async def decorated_tool(ctx: Context, query: str) -> str: - assert isinstance(ctx, Context) - return f"query: {query}" - - manager = ToolManager() - tool = Tool.from_function(decorated_tool) - manager.add_tool(tool) - - # Verify ctx is excluded from schema - assert "ctx" not in json.dumps(tool.parameters) - - mcp = FastMCP() - context = Context(fastmcp=mcp) - - async with context: - result = await manager.call_tool("decorated_tool", {"query": "test"}) - assert result.structured_content == {"result": "query: test"} - - -class TestCustomToolNames: - """Test adding tools with custom names that differ from their function names.""" - - async def test_add_tool_with_custom_name(self): - """Test adding a tool with a custom name parameter using add_tool_from_fn.""" - - def original_fn(x: int) -> int: - return x * 2 - - manager = ToolManager() - tool = Tool.from_function(original_fn, name="custom_name") - manager.add_tool(tool) - - # The tool is stored under the custom name and its .name is also set to custom_name - assert await manager.get_tool("custom_name") is not None - assert tool.name == "custom_name" - assert isinstance(tool, FunctionTool) - assert get_fn_name(tool.fn) == "original_fn" - # The tool should not be accessible via its original function name - with pytest.raises(NotFoundError, match="Tool 'original_fn' not found"): - await manager.get_tool("original_fn") - - async def test_call_tool_with_custom_name(self): - """Test calling a tool added with a custom name.""" - - def multiply(a: int, b: int) -> int: - """Multiply two numbers.""" - return a * b - - manager = ToolManager() - tool = Tool.from_function(multiply, name="custom_multiply") - manager.add_tool(tool) - - # Tool should be callable by its custom name - result = await manager.call_tool("custom_multiply", {"a": 5, "b": 3}) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "15" - assert result.structured_content == {"result": 15} - - # Original name should not be registered - with pytest.raises(NotFoundError, match="Tool 'multiply' not found"): - await manager.call_tool("multiply", {"a": 5, "b": 3}) - - async def test_replace_tool_keeps_original_name(self): - """Test that replacing a tool with "replace" keeps the original name.""" - - def original_fn(x: int) -> int: - return x - - def replacement_fn(x: int) -> int: - return x * 2 - - # Create a manager with REPLACE behavior - manager = ToolManager(duplicate_behavior="replace") - - # Add the original tool - original_tool = Tool.from_function(original_fn, name="test_tool") - manager.add_tool(original_tool) - assert original_tool.name == "test_tool" - - # Replace with a new function but keep the same registered name - replacement_tool = Tool.from_function(replacement_fn, name="test_tool") - manager.add_tool(replacement_tool) - - # The tool object should have been replaced - stored_tool = await manager.get_tool("test_tool") - assert stored_tool is not None - assert stored_tool == replacement_tool - - # The name should still be the same - assert stored_tool.name == "test_tool" - - # But the function is different - assert isinstance(stored_tool, FunctionTool) - assert get_fn_name(stored_tool.fn) == "replacement_fn" - - -class TestToolErrorHandling: - """Test error handling in the ToolManager.""" - - async def test_tool_error_passthrough(self): - """Test that ToolErrors are passed through directly.""" - manager = ToolManager() - - def error_tool(x: int) -> int: - """Tool that raises a ToolError.""" - raise ToolError("Specific tool error") - - manager.add_tool(Tool.from_function(error_tool)) - - with pytest.raises(ToolError, match="Specific tool error"): - await manager.call_tool("error_tool", {"x": 42}) - - async def test_exception_converted_to_tool_error_with_details(self): - """Test that other exceptions include details by default.""" - manager = ToolManager() - - def buggy_tool(x: int) -> int: - """Tool that raises a ValueError.""" - raise ValueError("Internal error details") - - manager.add_tool(Tool.from_function(buggy_tool)) - - with pytest.raises(ToolError) as excinfo: - await manager.call_tool("buggy_tool", {"x": 42}) - - # Exception message should include the tool name and the internal details - assert "Error calling tool 'buggy_tool'" in str(excinfo.value) - assert "Internal error details" in str(excinfo.value) - - async def test_exception_converted_to_masked_tool_error(self): - """Test that other exceptions are masked when enabled.""" - manager = ToolManager(mask_error_details=True) - - def buggy_tool(x: int) -> int: - """Tool that raises a ValueError.""" - raise ValueError("Internal error details") - - manager.add_tool(Tool.from_function(buggy_tool)) - - with pytest.raises(ToolError) as excinfo: - await manager.call_tool("buggy_tool", {"x": 42}) - - # Exception message should only contain the tool name, not the internal details - assert "Error calling tool 'buggy_tool'" in str(excinfo.value) - assert "Internal error details" not in str(excinfo.value) - - async def test_async_tool_error_passthrough(self): - """Test that ToolErrors from async tools are passed through directly.""" - manager = ToolManager() - - async def async_error_tool(x: int) -> int: - """Async tool that raises a ToolError.""" - raise ToolError("Async tool error") - - manager.add_tool(Tool.from_function(async_error_tool)) - - with pytest.raises(ToolError, match="Async tool error"): - await manager.call_tool("async_error_tool", {"x": 42}) - - async def test_async_exception_converted_to_tool_error_with_details(self): - """Test that other exceptions from async tools include details by default.""" - manager = ToolManager() - - async def async_buggy_tool(x: int) -> int: - """Async tool that raises a ValueError.""" - raise ValueError("Internal async error details") - - manager.add_tool(Tool.from_function(async_buggy_tool)) - - with pytest.raises(ToolError) as excinfo: - await manager.call_tool("async_buggy_tool", {"x": 42}) - - # Exception message should include the tool name and the internal details - assert "Error calling tool 'async_buggy_tool'" in str(excinfo.value) - assert "Internal async error details" in str(excinfo.value) - - async def test_async_exception_converted_to_masked_tool_error(self): - """Test that other exceptions from async tools are masked when enabled.""" - manager = ToolManager(mask_error_details=True) - - async def async_buggy_tool(x: int) -> int: - """Async tool that raises a ValueError.""" - raise ValueError("Internal async error details") - - manager.add_tool(Tool.from_function(async_buggy_tool)) - - with pytest.raises(ToolError) as excinfo: - await manager.call_tool("async_buggy_tool", {"x": 42}) - - # Exception message should contain the tool name but not the internal details - assert "Error calling tool 'async_buggy_tool'" in str(excinfo.value) - assert "Internal async error details" not in str(excinfo.value) - - -class TestMountedComponentsRaiseOnLoadError: - """Test the mounted_components_raise_on_load_error setting.""" - - async def test_mounted_components_raise_on_load_error_default_false(self): - """Test that by default, mounted component load errors are warned and not raised.""" - import fastmcp - - # Ensure default setting is False - assert fastmcp.settings.mounted_components_raise_on_load_error is False - - parent_mcp = FastMCP("ParentServer") - child_mcp = FastMCP("FailingChildServer") - - # Create a failing mounted server by corrupting it - parent_mcp.mount(child_mcp, namespace="child") - # Corrupt the parent's providers to make it fail during loading - assert isinstance(parent_mcp._providers, list) - parent_mcp._providers.append("invalid") # type: ignore[arg-type] - - # Should not raise, just warn; use server middleware path now - tools = await parent_mcp._list_tools_middleware() - assert isinstance(tools, list) # Should return list, not raise - - async def test_mounted_components_raise_on_load_error_true(self): - """Test that when enabled, mounted component load errors are raised.""" - parent_mcp = FastMCP("ParentServer") - child_mcp = FastMCP("FailingChildServer") - - # Create a failing mounted server - parent_mcp.mount(child_mcp, namespace="child") - # Corrupt the parent's providers to make it fail during loading - assert isinstance(parent_mcp._providers, list) - parent_mcp._providers.append("invalid") # type: ignore[arg-type] - - # Use temporary settings context manager - with temporary_settings(mounted_components_raise_on_load_error=True): - # Should raise the exception - with pytest.raises( - AttributeError, match="'str' object has no attribute 'list_tools'" - ): - await parent_mcp._list_tools_middleware()