diff --git a/src/fastmcp/server/providers/base.py b/src/fastmcp/server/providers/base.py index 6c35fcf141..c65c266a01 100644 --- a/src/fastmcp/server/providers/base.py +++ b/src/fastmcp/server/providers/base.py @@ -35,6 +35,7 @@ async def get_tool(self, name: str) -> Tool | None: from fastmcp.resources.resource import Resource from fastmcp.resources.template import ResourceTemplate from fastmcp.tools.tool import Tool +from fastmcp.utilities.async_utils import gather from fastmcp.utilities.components import FastMCPComponent from fastmcp.utilities.visibility import VisibilityFilter @@ -60,6 +61,9 @@ class Provider: def __init__(self) -> None: self._visibility = VisibilityFilter() + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + def with_transforms( self, *, @@ -219,19 +223,18 @@ async def get_component( Returns: The component if found, or None to continue searching other providers. """ - # Default implementation: iterate through all components and match by key - for tool in await self.list_tools(): - if tool.key == key: - return tool - for resource in await self.list_resources(): - if resource.key == key: - return resource - for template in await self.list_resource_templates(): - if template.key == key: - return template - for prompt in await self.list_prompts(): - if prompt.key == key: - return prompt + # Default implementation: fetch all component types in parallel + # Exceptions propagate since return_exceptions=False + results = await gather( + self.list_tools(), + self.list_resources(), + self.list_resource_templates(), + self.list_prompts(), + ) + for components in results: + for component in components: # type: ignore[union-attr] + if component.key == key: + return component return None # ------------------------------------------------------------------------- diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 99f030e9ed..b3a8c40db2 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -85,6 +85,7 @@ from fastmcp.settings import Settings from fastmcp.tools.tool import FunctionTool, Tool, ToolResult from fastmcp.tools.tool_transform import ToolTransformConfig +from fastmcp.utilities.async_utils import gather from fastmcp.utilities.cli import log_server_banner from fastmcp.utilities.components import FastMCPComponent from fastmcp.utilities.logging import get_logger, temporary_log_level @@ -473,20 +474,23 @@ async def _docket_lifespan(self) -> AsyncIterator[None]: # Store on server instance for cross-task access (FastMCPTransport) self._docket = docket - # Register task-enabled components from all providers (LocalProvider first) - for provider in self._providers: - try: - for component in await provider.get_tasks(): - component.register_with_docket(docket) - except Exception as e: - provider_name = getattr( - provider, "server", provider - ).__class__.__name__ + # Register task-enabled components from all providers in parallel + task_results = await gather( + *[p.get_tasks() for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(task_results): + if isinstance(result, BaseException): + provider = self._providers[i] logger.warning( - f"Failed to register tasks from provider {provider_name!r}: {e}" + f"Failed to register tasks from {provider}: {result}" ) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result + continue + for component in result: + component.register_with_docket(docket) # Set Docket in ContextVar so CurrentDocket can access it docket_token = _current_docket.set(docket) @@ -789,39 +793,47 @@ def _is_component_enabled(self, component: FastMCPComponent) -> bool: async def get_tools(self) -> dict[str, Tool]: """Get all enabled tools from providers, indexed by name. - Iterates through all providers (LocalProvider first) and collects tools. + Queries all providers in parallel and collects tools. First provider wins for duplicate names. Filters by server blocklist. """ + results = await gather( + *[p.list_tools() for p in self._providers], + return_exceptions=True, + ) + all_tools: dict[str, Tool] = {} - for provider in self._providers: - try: - provider_tools = await provider.list_tools() - for tool in provider_tools: - if tool.name not in all_tools and self._is_component_enabled(tool): - all_tools[tool.name] = tool - except Exception as e: - provider_name = getattr(provider, "server", provider).__class__.__name__ - logger.warning( - f"Failed to get tools from provider {provider_name!r}: {e}" - ) + for i, result in enumerate(results): + if isinstance(result, BaseException): + provider = self._providers[i] + logger.warning(f"Failed to get tools from {provider}: {result}") if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result continue + for tool in result: + if tool.name not in all_tools and self._is_component_enabled(tool): + all_tools[tool.name] = tool return all_tools async def get_tool(self, name: str) -> Tool: """Get an enabled tool by name. - Iterates through all providers (LocalProvider first) to find the tool. + Queries all providers in parallel to find the tool. First provider wins. Returns only if enabled. """ - for provider in self._providers: - try: - tool = await provider.get_tool(name) - if tool is not None and self._is_component_enabled(tool): - return tool - except NotFoundError: + results = await gather( + *[p.get_tool(name) for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug( + f"Error getting tool from {self._providers[i]}: {result}" + ) continue + if isinstance(result, Tool) and self._is_component_enabled(result): + return result raise NotFoundError(f"Unknown tool: {name}") @@ -833,151 +845,176 @@ async def _get_resource_or_template_or_none( Returns the original ResourceTemplate (not a Resource created from it) to preserve the registered function for task execution. - Iterates through all providers (LocalProvider first). + Queries all providers in parallel. First provider wins. Checks concrete resources first, then templates. """ - # First pass: check concrete resources from all providers - for provider in self._providers: - try: - resource = await provider.get_resource(uri) - if resource is not None and self._is_component_enabled(resource): - return resource - except NotFoundError: - continue + # Resources listed first so they have priority over templates + results = await gather( + *[p.get_resource(uri) for p in self._providers], + *[p.get_resource_template(uri) for p in self._providers], + return_exceptions=True, + ) - # Second pass: check templates from all providers - for provider in self._providers: - try: - template = await provider.get_resource_template(uri) - if template is not None and self._is_component_enabled(template): - return template - except NotFoundError: + for result in results: + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug(f"Error getting resource/template: {result}") continue + if isinstance( + result, (Resource, ResourceTemplate) + ) and self._is_component_enabled(result): + return result return None async def get_resources(self) -> dict[str, Resource]: """Get all enabled resources from providers, indexed by URI. - Iterates through all providers (LocalProvider first) and collects resources. + Queries all providers in parallel and collects resources. First provider wins for duplicate URIs. Filters by server blocklist. """ + results = await gather( + *[p.list_resources() for p in self._providers], + return_exceptions=True, + ) + all_resources: dict[str, Resource] = {} - for provider in self._providers: - try: - provider_resources = await provider.list_resources() - for resource in provider_resources: - uri = str(resource.uri) - if uri not in all_resources and self._is_component_enabled( - resource - ): - all_resources[uri] = resource - except Exception as e: - provider_name = getattr(provider, "server", provider).__class__.__name__ - logger.warning( - f"Failed to get resources from provider {provider_name!r}: {e}" - ) + for i, result in enumerate(results): + if isinstance(result, BaseException): + provider = self._providers[i] + logger.warning(f"Failed to get resources from {provider}: {result}") if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result continue + for resource in result: + uri = str(resource.uri) + if uri not in all_resources and self._is_component_enabled(resource): + all_resources[uri] = resource return all_resources async def get_resource(self, uri: str) -> Resource: """Get an enabled resource by URI. - Iterates through all providers (LocalProvider first) to find the resource. + Queries all providers in parallel to find the resource. First provider wins. Returns only if enabled. """ - for provider in self._providers: - try: - resource = await provider.get_resource(uri) - if resource is not None and self._is_component_enabled(resource): - return resource - except NotFoundError: + results = await gather( + *[p.get_resource(uri) for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug( + f"Error getting resource from {self._providers[i]}: {result}" + ) continue + if isinstance(result, Resource) and self._is_component_enabled(result): + return result raise NotFoundError(f"Unknown resource: {uri}") async def get_resource_templates(self) -> dict[str, ResourceTemplate]: """Get all enabled resource templates from providers, indexed by uri_template. - Iterates through all providers (LocalProvider first) and collects templates. + Queries all providers in parallel and collects templates. First provider wins for duplicate uri_templates. Filters by server blocklist. """ + results = await gather( + *[p.list_resource_templates() for p in self._providers], + return_exceptions=True, + ) + all_templates: dict[str, ResourceTemplate] = {} - for provider in self._providers: - try: - provider_templates = await provider.list_resource_templates() - for template in provider_templates: - if ( - template.uri_template not in all_templates - and self._is_component_enabled(template) - ): - all_templates[template.uri_template] = template - except Exception as e: - provider_name = getattr(provider, "server", provider).__class__.__name__ + for i, result in enumerate(results): + if isinstance(result, BaseException): + provider = self._providers[i] logger.warning( - f"Failed to get resource templates from provider {provider_name!r}: {e}" + f"Failed to get resource templates from {provider}: {result}" ) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result continue + for template in result: + if ( + template.uri_template not in all_templates + and self._is_component_enabled(template) + ): + all_templates[template.uri_template] = template return all_templates async def get_resource_template(self, uri: str) -> ResourceTemplate: """Get an enabled resource template that matches the given URI. - Iterates through all providers (LocalProvider first) to find the template. + Queries all providers in parallel to find the template. First provider wins. Returns only if enabled. """ - for provider in self._providers: - try: - template = await provider.get_resource_template(uri) - if template is not None and self._is_component_enabled(template): - return template - except NotFoundError: + results = await gather( + *[p.get_resource_template(uri) for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug( + f"Error getting template from {self._providers[i]}: {result}" + ) continue + if isinstance(result, ResourceTemplate) and self._is_component_enabled( + result + ): + return result raise NotFoundError(f"Unknown resource template: {uri}") async def get_prompts(self) -> dict[str, Prompt]: """Get all enabled prompts from providers, indexed by name. - Iterates through all providers (LocalProvider first) and collects prompts. + Queries all providers in parallel and collects prompts. First provider wins for duplicate names. Filters by server blocklist. """ + results = await gather( + *[p.list_prompts() for p in self._providers], + return_exceptions=True, + ) + all_prompts: dict[str, Prompt] = {} - for provider in self._providers: - try: - provider_prompts = await provider.list_prompts() - for prompt in provider_prompts: - if prompt.name not in all_prompts and self._is_component_enabled( - prompt - ): - all_prompts[prompt.name] = prompt - except Exception as e: - provider_name = getattr(provider, "server", provider).__class__.__name__ - logger.warning( - f"Failed to get prompts from provider {provider_name!r}: {e}" - ) + for i, result in enumerate(results): + if isinstance(result, BaseException): + provider = self._providers[i] + logger.warning(f"Failed to get prompts from {provider}: {result}") if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result continue + for prompt in result: + if prompt.name not in all_prompts and self._is_component_enabled( + prompt + ): + all_prompts[prompt.name] = prompt return all_prompts async def get_prompt(self, name: str) -> Prompt: """Get an enabled prompt by name. - Iterates through all providers (LocalProvider first) to find the prompt. + Queries all providers in parallel to find the prompt. First provider wins. Returns only if enabled. """ - for provider in self._providers: - try: - prompt = await provider.get_prompt(name) - if prompt is not None and self._is_component_enabled(prompt): - return prompt - except NotFoundError: + results = await gather( + *[p.get_prompt(name) for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug( + f"Error getting prompt from {self._providers[i]}: {result}" + ) continue + if isinstance(result, Prompt) and self._is_component_enabled(result): + return result raise NotFoundError(f"Unknown prompt: {name}") @@ -986,7 +1023,7 @@ async def get_component( ) -> Tool | Resource | ResourceTemplate | Prompt: """Get a component by its prefixed key. - Iterates through all providers (LocalProvider first) to find the component. + Queries all providers in parallel to find the component. First provider wins. Args: @@ -998,13 +1035,20 @@ async def get_component( Raises: NotFoundError: If no component is found with the given key. """ - for provider in self._providers: - try: - component = await provider.get_component(key) - if component is not None: - return component - except NotFoundError: + results = await gather( + *[p.get_component(key) for p in self._providers], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, BaseException): + if not isinstance(result, NotFoundError): + logger.debug( + f"Error getting component from {self._providers[i]}: {result}" + ) continue + if isinstance(result, FastMCPComponent): + return result raise NotFoundError(f"Unknown component: {key}") @@ -1115,20 +1159,24 @@ async def _list_tools( """ List all available tools. - Iterates through all providers (LocalProvider first) and collects tools. + Queries all providers in parallel and collects tools. First provider wins for duplicate keys. """ + results = await gather( + *[p.list_tools() for p in self._providers], + return_exceptions=True, + ) + all_tools: dict[str, Tool] = {} - for provider in self._providers: - try: - provider_tools = await provider.list_tools() - for tool in provider_tools: - if self._is_component_enabled(tool) and tool.key not in all_tools: - all_tools[tool.key] = tool - except Exception: - logger.exception("Error listing tools from provider") + for result in results: + if isinstance(result, BaseException): + logger.exception("Error listing tools from provider", exc_info=result) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result + continue + for tool in result: + if self._is_component_enabled(tool) and tool.key not in all_tools: + all_tools[tool.key] = tool return list(all_tools.values()) async def _list_resources_mcp(self) -> list[SDKResource]: @@ -1177,23 +1225,29 @@ async def _list_resources( """ List all available resources. - Iterates through all providers (LocalProvider first) and collects resources. + Queries all providers in parallel and collects resources. First provider wins for duplicate keys. """ + results = await gather( + *[p.list_resources() for p in self._providers], + return_exceptions=True, + ) + all_resources: dict[str, Resource] = {} - for provider in self._providers: - try: - provider_resources = await provider.list_resources() - for resource in provider_resources: - if ( - self._is_component_enabled(resource) - and resource.key not in all_resources - ): - all_resources[resource.key] = resource - except Exception: - logger.exception("Error listing resources from provider") + for result in results: + if isinstance(result, BaseException): + logger.exception( + "Error listing resources from provider", exc_info=result + ) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result + continue + for resource in result: + if ( + self._is_component_enabled(resource) + and resource.key not in all_resources + ): + all_resources[resource.key] = resource return list(all_resources.values()) async def _list_resource_templates_mcp(self) -> list[SDKResourceTemplate]: @@ -1243,23 +1297,29 @@ async def _list_resource_templates( """ List all available resource templates. - Iterates through all providers (LocalProvider first) and collects templates. + Queries all providers in parallel and collects templates. First provider wins for duplicate keys. """ + results = await gather( + *[p.list_resource_templates() for p in self._providers], + return_exceptions=True, + ) + all_templates: dict[str, ResourceTemplate] = {} - for provider in self._providers: - try: - provider_templates = await provider.list_resource_templates() - for template in provider_templates: - if ( - self._is_component_enabled(template) - and template.key not in all_templates - ): - all_templates[template.key] = template - except Exception: - logger.exception("Error listing resource templates from provider") + for result in results: + if isinstance(result, BaseException): + logger.exception( + "Error listing resource templates from provider", exc_info=result + ) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result + continue + for template in result: + if ( + self._is_component_enabled(template) + and template.key not in all_templates + ): + all_templates[template.key] = template return list(all_templates.values()) async def _list_prompts_mcp(self) -> list[SDKPrompt]: @@ -1309,23 +1369,24 @@ async def _list_prompts( """ List all available prompts. - Iterates through all providers (LocalProvider first) and collects prompts. + Queries all providers in parallel and collects prompts. First provider wins for duplicate keys. """ + results = await gather( + *[p.list_prompts() for p in self._providers], + return_exceptions=True, + ) + all_prompts: dict[str, Prompt] = {} - for provider in self._providers: - try: - provider_prompts = await provider.list_prompts() - for prompt in provider_prompts: - if ( - self._is_component_enabled(prompt) - and prompt.key not in all_prompts - ): - all_prompts[prompt.key] = prompt - except Exception: - logger.exception("Error listing prompts from provider") + for result in results: + if isinstance(result, BaseException): + logger.exception("Error listing prompts from provider", exc_info=result) if fastmcp.settings.mounted_components_raise_on_load_error: - raise + raise result + continue + for prompt in result: + if self._is_component_enabled(prompt) and prompt.key not in all_prompts: + all_prompts[prompt.key] = prompt return list(all_prompts.values()) async def _call_tool_mcp( @@ -1515,7 +1576,7 @@ async def _call_tool( """ Call a tool. - Iterates through all providers (LocalProvider first) to find the tool. + Iterates through all providers to find the tool. First provider wins. """ tool_name = context.message.name @@ -1614,7 +1675,7 @@ async def _read_resource( """ Read a resource. - Iterates through all providers (LocalProvider first) to find the resource. + Iterates through all providers to find the resource. First provider wins. Checks concrete resources first, then templates. Returns list[ResourceContent] for synchronous execution, or CreateTaskResult @@ -1758,7 +1819,7 @@ async def _get_prompt( """ Get a prompt. - Iterates through all providers (LocalProvider first) to find the prompt. + Iterates through all providers to find the prompt. First provider wins. Returns PromptResult for synchronous execution, or CreateTaskResult diff --git a/src/fastmcp/utilities/async_utils.py b/src/fastmcp/utilities/async_utils.py new file mode 100644 index 0000000000..b69077c9d2 --- /dev/null +++ b/src/fastmcp/utilities/async_utils.py @@ -0,0 +1,56 @@ +"""Async utilities for FastMCP.""" + +from collections.abc import Awaitable +from typing import Literal, TypeVar, overload + +import anyio + +T = TypeVar("T") + + +@overload +async def gather( + *awaitables: Awaitable[T], + return_exceptions: Literal[True], +) -> list[T | BaseException]: ... + + +@overload +async def gather( + *awaitables: Awaitable[T], + return_exceptions: Literal[False] = ..., +) -> list[T]: ... + + +async def gather( + *awaitables: Awaitable[T], + return_exceptions: bool = False, +) -> list[T] | list[T | BaseException]: + """Run awaitables concurrently and return results in order. + + Uses anyio TaskGroup for structured concurrency. + + Args: + *awaitables: Awaitables to run concurrently + return_exceptions: If True, exceptions are returned in results. + If False, first exception cancels all and raises. + + Returns: + List of results in the same order as input awaitables. + """ + results: list[T | BaseException] = [None] * len(awaitables) # type: ignore[assignment] + + async def run_at(i: int, aw: Awaitable[T]) -> None: + try: + results[i] = await aw + except BaseException as e: + if return_exceptions: + results[i] = e + else: + raise + + async with anyio.create_task_group() as tg: + for i, aw in enumerate(awaitables): + tg.start_soon(run_at, i, aw) + + return results