diff --git a/docs/servers/context.mdx b/docs/servers/context.mdx index b516b31d31..5dcea7f67b 100644 --- a/docs/servers/context.mdx +++ b/docs/servers/context.mdx @@ -23,6 +23,7 @@ The `Context` object provides a clean interface to access MCP features within yo - **LLM Sampling**: Request the client's LLM to generate text based on provided messages - **User Elicitation**: Request structured input from users during tool execution - **Session State**: Store data that persists across requests within an MCP session +- **Session Visibility**: [Control which components are visible](/servers/enabled#per-session-visibility) to the current session - **Request Information**: Access metadata about the current request - **Server Access**: When needed, access the underlying FastMCP server instance @@ -261,6 +262,12 @@ Any backend compatible with the [py-key-value-aio](https://github.com/strawgate/ State set during `on_initialize` middleware persists to subsequent tool calls when using the same session object (STDIO, SSE, single-server HTTP). For distributed/serverless HTTP deployments where different machines handle init and tool calls, state is isolated by the `mcp-session-id` header. +### Session Visibility + + + +Tools can customize which components are visible to their current session using `ctx.enable_components()`, `ctx.disable_components()`, and `ctx.reset_components()`. These methods apply visibility rules that affect only the calling session, leaving other sessions unchanged. See [Per-Session Visibility](/servers/enabled#per-session-visibility) for complete documentation, filter criteria, and patterns like namespace activation. + ### Change Notifications diff --git a/docs/servers/enabled.mdx b/docs/servers/enabled.mdx index 39cff4da24..bd718aa149 100644 --- a/docs/servers/enabled.mdx +++ b/docs/servers/enabled.mdx @@ -3,6 +3,7 @@ title: Component Visibility sidebarTitle: Component Visibility description: Control which components are available to clients icon: toggle-on +tag: NEW --- import { VersionBadge } from '/snippets/version-badge.mdx' @@ -285,9 +286,153 @@ def check_permissions(ctx: Context) -> str: return "Admin tools disabled" ``` - -Dynamic enabled state changes affect all connected clients. For per-user filtering, consider using separate server instances or implementing authorization in the tools themselves. - +## Per-Session Visibility + +Server-level visibility changes affect all connected clients simultaneously. When you need different clients to see different components, use per-session visibility instead. + +Session visibility lets individual sessions customize their view of available components. When a tool calls `ctx.enable_components()` or `ctx.disable_components()`, those rules apply only to the current session. Other sessions continue to see the global defaults. This enables patterns like progressive disclosure, role-based access, and on-demand feature activation. + +```python +from fastmcp import FastMCP +from fastmcp.server.context import Context + +mcp = FastMCP("Session-Aware Server") + +@mcp.tool(tags={"premium"}) +def premium_analysis(data: str) -> str: + """Advanced analysis available to premium users.""" + return f"Premium analysis of: {data}" + +@mcp.tool +async def unlock_premium(ctx: Context) -> str: + """Unlock premium features for this session.""" + await ctx.enable_components(tags={"premium"}) + return "Premium features unlocked" + +@mcp.tool +async def reset_features(ctx: Context) -> str: + """Reset to default feature set.""" + await ctx.reset_components() + return "Features reset to defaults" + +# Premium tools are disabled globally by default +mcp.disable(tags={"premium"}) +``` + +All sessions start with `premium_analysis` hidden. When a session calls `unlock_premium`, that session gains access to premium tools while other sessions remain unaffected. Calling `reset_features` returns the session to the global defaults. + +### How Session Rules Work + +Session rules override global transforms. When listing components, FastMCP first applies global enable/disable rules, then applies session-specific rules on top. Rules within a session accumulate, and later rules override earlier ones for the same component. + +```python +@mcp.tool +async def customize_session(ctx: Context) -> str: + # Enable finance tools for this session + await ctx.enable_components(tags={"finance"}) + + # Also enable admin tools + await ctx.enable_components(tags={"admin"}) + + # Later: disable a specific admin tool + await ctx.disable_components(names={"dangerous_admin_tool"}) + + return "Session customized" +``` + +Each call adds a rule to the session. The `dangerous_admin_tool` ends up disabled because its disable rule was added after the admin enable rule. + +### Filter Criteria + +The session visibility methods accept the same filter criteria as `server.enable()` and `server.disable()`: + +| Parameter | Description | +|-----------|-------------| +| `names` | Component names or URIs to match | +| `keys` | Component keys (e.g., `{"tool:my_tool"}`) | +| `tags` | Tags to match (component must have at least one) | +| `version` | Version specification to match | +| `components` | Component types (`{"tool"}`, `{"resource"}`, `{"prompt"}`, `{"template"}`) | +| `match_all` | If `True`, matches all components regardless of other criteria | + +```python +from fastmcp.utilities.versions import VersionSpec + +@mcp.tool +async def enable_recent_tools(ctx: Context) -> str: + """Enable only tools from version 2.0.0 or later.""" + await ctx.enable_components( + version=VersionSpec(gte="2.0.0"), + components={"tool"} + ) + return "Recent tools enabled" +``` + +### Automatic Notifications + +When session visibility changes, FastMCP automatically sends notifications to that session. Clients receive `ToolListChangedNotification`, `ResourceListChangedNotification`, and `PromptListChangedNotification` so they can refresh their component lists. These notifications go only to the affected session. + +When you specify the `components` parameter, FastMCP optimizes by sending only the relevant notifications: + +```python +# Only sends ToolListChangedNotification +await ctx.enable_components(tags={"finance"}, components={"tool"}) + +# Sends all three notifications (no components filter) +await ctx.enable_components(tags={"finance"}) +``` + +### Namespace Activation Pattern + +A common pattern organizes tools into namespaces using tag prefixes, disables them globally, then provides activation tools that unlock namespaces on demand: + +```python +from fastmcp import FastMCP +from fastmcp.server.context import Context + +server = FastMCP("Multi-Domain Assistant") + +# Finance namespace +@server.tool(tags={"namespace:finance"}) +def analyze_portfolio(symbols: list[str]) -> str: + return f"Analysis for: {', '.join(symbols)}" + +@server.tool(tags={"namespace:finance"}) +def get_market_data(symbol: str) -> dict: + return {"symbol": symbol, "price": 150.25} + +# Admin namespace +@server.tool(tags={"namespace:admin"}) +def list_users() -> list[str]: + return ["alice", "bob", "charlie"] + +# Activation tools - always visible +@server.tool +async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"namespace:finance"}) + return "Finance tools activated" + +@server.tool +async def activate_admin(ctx: Context) -> str: + await ctx.enable_components(tags={"namespace:admin"}) + return "Admin tools activated" + +@server.tool +async def deactivate_all(ctx: Context) -> str: + await ctx.reset_components() + return "All namespaces deactivated" + +# Disable namespace tools globally +server.disable(tags={"namespace:finance", "namespace:admin"}) +``` + +Sessions start seeing only the activation tools. Calling `activate_finance` reveals finance tools for that session only. Multiple namespaces can be activated independently, and `deactivate_all` returns to the initial state. + +### Method Reference + +- **`await ctx.enable_components(...) -> None`**: Enable matching components for this session +- **`await ctx.disable_components(...) -> None`**: Disable matching components for this session +- **`await ctx.reset_components() -> None`**: Clear all session rules, returning to global defaults ## Client Notifications diff --git a/examples/namespace_activation/README.md b/examples/namespace_activation/README.md new file mode 100644 index 0000000000..3dffb146d2 --- /dev/null +++ b/examples/namespace_activation/README.md @@ -0,0 +1,55 @@ +# Namespace Activation + +Demonstrates session-specific visibility control using tags to organize tools into namespaces that can be activated on demand. + +## Pattern + +1. Tag tools with namespaces: `@server.tool(tags={"namespace:finance"})` +2. Globally disable namespaces: `server.disable(tags={"namespace:finance"})` +3. Provide activation tools that call `ctx.enable_components(tags={"namespace:finance"})` + +Each session starts with only the activation tools visible. When a session calls an activation tool, that namespace becomes visible **only for that session**. + +## Run + +```bash +# Server +uv run python server.py + +# Client (in another terminal) +uv run python client.py +``` + +## Example Output + +``` +Namespace Activation Demo + +╭─────────────────── Initial Tools ───────────────────╮ +│ activate_finance, activate_admin, deactivate_all │ +╰─────────────────────────────────────────────────────╯ + +→ Calling activate_finance() + Finance tools activated +╭─────────────── After Activating Finance ────────────╮ +│ analyze_portfolio, get_market_data, execute_trade, │ +│ activate_finance, activate_admin, deactivate_all │ +╰─────────────────────────────────────────────────────╯ + +→ Calling get_market_data(symbol='AAPL') + {'symbol': 'AAPL', 'price': 150.25, 'change': '+2.5%'} + +→ Calling activate_admin() + Admin tools activated +╭────────────── After Activating Admin ───────────────╮ +│ analyze_portfolio, get_market_data, execute_trade, │ +│ list_users, reset_user_password, activate_finance, │ +│ activate_admin, deactivate_all │ +╰─────────────────────────────────────────────────────╯ + +→ Calling deactivate_all() + All namespaces deactivated +╭────────────── After Deactivating All ───────────────╮ +│ activate_finance, activate_admin, deactivate_all │ +╰─────────────────────────────────────────────────────╯ +``` diff --git a/examples/namespace_activation/client.py b/examples/namespace_activation/client.py new file mode 100644 index 0000000000..7090f27e42 --- /dev/null +++ b/examples/namespace_activation/client.py @@ -0,0 +1,76 @@ +""" +Namespace Activation Client + +Demonstrates how session-specific visibility works from the client perspective. +""" + +import asyncio +import sys +from pathlib import Path + +from rich import print +from rich.panel import Panel + +from fastmcp import Client + + +def load_server(): + """Load the example server.""" + examples_dir = Path(__file__).parent + if str(examples_dir) not in sys.path: + sys.path.insert(0, str(examples_dir)) + + import server as server_module + + return server_module.server + + +server = load_server() + + +def show_tools(tools: list, title: str) -> None: + """Display available tools in a panel.""" + tool_names = [f"[cyan]{t.name}[/]" for t in tools] + print(Panel(", ".join(tool_names) or "[dim]No tools[/]", title=title)) + + +async def main(): + print("\n[bold]Namespace Activation Demo[/]\n") + + async with Client(server) as client: + # Initially only activation tools are visible + tools = await client.list_tools() + show_tools(tools, "Initial Tools") + + # Activate finance namespace + print("\n[yellow]→ Calling activate_finance()[/]") + result = await client.call_tool("activate_finance", {}) + print(f" [green]{result.data}[/]") + + tools = await client.list_tools() + show_tools(tools, "After Activating Finance") + + # Use a finance tool + print("\n[yellow]→ Calling get_market_data(symbol='AAPL')[/]") + result = await client.call_tool("get_market_data", {"symbol": "AAPL"}) + print(f" [green]{result.data}[/]") + + # Activate admin namespace too + print("\n[yellow]→ Calling activate_admin()[/]") + result = await client.call_tool("activate_admin", {}) + print(f" [green]{result.data}[/]") + + tools = await client.list_tools() + show_tools(tools, "After Activating Admin") + + # Deactivate all - back to defaults + print("\n[yellow]→ Calling deactivate_all()[/]") + result = await client.call_tool("deactivate_all", {}) + print(f" [green]{result.data}[/]") + + tools = await client.list_tools() + show_tools(tools, "After Deactivating All") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/namespace_activation/server.py b/examples/namespace_activation/server.py new file mode 100644 index 0000000000..758790b0c9 --- /dev/null +++ b/examples/namespace_activation/server.py @@ -0,0 +1,73 @@ +""" +Namespace Activation Server + +Tools are organized into namespaces using tags, globally disabled by default, +and selectively enabled per-session via activation tools. +""" + +from fastmcp import FastMCP +from fastmcp.server.context import Context + +server = FastMCP("Multi-Domain Assistant") + + +# Finance namespace +@server.tool(tags={"namespace:finance"}) +def analyze_portfolio(symbols: list[str]) -> str: + """Analyze a portfolio of stock symbols.""" + return f"Portfolio analysis for: {', '.join(symbols)}" + + +@server.tool(tags={"namespace:finance"}) +def get_market_data(symbol: str) -> dict: + """Get current market data for a symbol.""" + return {"symbol": symbol, "price": 150.25, "change": "+2.5%"} + + +@server.tool(tags={"namespace:finance"}) +def execute_trade(symbol: str, quantity: int, side: str) -> str: + """Execute a trade (simulated).""" + return f"Executed {side} order: {quantity} shares of {symbol}" + + +# Admin namespace +@server.tool(tags={"namespace:admin"}) +def list_users() -> list[str]: + """List all system users.""" + return ["alice", "bob", "charlie"] + + +@server.tool(tags={"namespace:admin"}) +def reset_user_password(username: str) -> str: + """Reset a user's password (simulated).""" + return f"Password reset for {username}" + + +# Activation tools - always visible +@server.tool +async def activate_finance(ctx: Context) -> str: + """Activate finance tools for this session.""" + await ctx.enable_components(tags={"namespace:finance"}) + return "Finance tools activated" + + +@server.tool +async def activate_admin(ctx: Context) -> str: + """Activate admin tools for this session.""" + await ctx.enable_components(tags={"namespace:admin"}) + return "Admin tools activated" + + +@server.tool +async def deactivate_all(ctx: Context) -> str: + """Deactivate all namespaces, returning to defaults.""" + await ctx.reset_components() + return "All namespaces deactivated" + + +# Globally disable namespace tools by default +server.disable(tags={"namespace:finance", "namespace:admin"}) + + +if __name__ == "__main__": + server.run() diff --git a/loq.toml b/loq.toml index 52222dfd2f..c54baae1b6 100644 --- a/loq.toml +++ b/loq.toml @@ -28,7 +28,7 @@ max_lines = 1899 [[rules]] path = "tests/server/middleware/test_middleware.py" -max_lines = 1250 +max_lines = 1070 [[rules]] path = "src/fastmcp/server/context.py" @@ -40,7 +40,7 @@ max_lines = 1748 [[rules]] path = "tests/server/test_mount.py" -max_lines = 1560 +max_lines = 1545 [[rules]] path = "tests/utilities/test_inspect.py" @@ -60,7 +60,7 @@ max_lines = 3250 [[rules]] path = "tests/tools/test_tool.py" -max_lines = 2250 +max_lines = 2026 [[rules]] path = "tests/client/test_elicitation.py" @@ -68,7 +68,7 @@ max_lines = 1132 [[rules]] path = "src/fastmcp/client/client.py" -max_lines = 2000 +max_lines = 1885 [[rules]] path = "tests/utilities/test_json_schema_type.py" @@ -95,9 +95,9 @@ path = "tests/server/auth/test_jwt_provider.py" max_lines = 1101 [[rules]] -path = "docs/servers/tools.mdx" -max_lines = 1200 +path = "src/fastmcp/server/providers/local_provider.py" +max_lines = 1187 [[rules]] -path = "docs/changelog.mdx" -max_lines = 2280 +path = "tests/server/test_versioning.py" +max_lines = 1235 diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index b245d164b2..00757da28e 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -54,9 +54,12 @@ execute_tools as run_sampling_tools, ) from fastmcp.server.server import FastMCP, StateValue +from fastmcp.server.transforms.enabled import Enabled +from fastmcp.utilities.components import FastMCPComponent from fastmcp.utilities.json_schema import compress_schema from fastmcp.utilities.logging import _clamp_logger, get_logger from fastmcp.utilities.types import get_cached_typeadapter +from fastmcp.utilities.versions import VersionSpec logger: Logger = get_logger(name=__name__) to_client_logger: Logger = logger.getChild(suffix="to_client") @@ -1141,6 +1144,194 @@ async def delete_state(self, key: str) -> None: prefixed_key = self._make_state_key(key) await self.fastmcp._state_store.delete(key=prefixed_key) + # ------------------------------------------------------------------------- + # Session visibility control + # ------------------------------------------------------------------------- + + async def _get_visibility_rules(self) -> list[dict[str, Any]]: + """Load visibility rule dicts from session state.""" + return await self.get_state("_visibility_rules") or [] + + async def _save_visibility_rules( + self, + rules: list[dict[str, Any]], + *, + components: set[Literal["tool", "resource", "template", "prompt"]] + | None = None, + ) -> None: + """Save visibility rule dicts to session state and send notifications. + + Args: + rules: The visibility rules to save. + components: Optional hint about which component types are affected. + If None, sends notifications for all types (safe default). + If provided, only sends notifications for specified types. + """ + await self.set_state("_visibility_rules", rules) + + # Send notifications based on components hint + # Note: MCP has no separate template notification - templates use ResourceListChangedNotification + if components is None or "tool" in components: + await self.send_notification(mcp.types.ToolListChangedNotification()) + if components is None or "resource" in components or "template" in components: + await self.send_notification(mcp.types.ResourceListChangedNotification()) + if components is None or "prompt" in components: + await self.send_notification(mcp.types.PromptListChangedNotification()) + + def _create_enabled_transforms(self, rules: list[dict[str, Any]]) -> list[Enabled]: + """Convert rule dicts to Enabled transforms.""" + transforms = [] + for params in rules: + version = None + if params.get("version"): + version_dict = params["version"] + version = VersionSpec( + gte=version_dict.get("gte"), + lt=version_dict.get("lt"), + eq=version_dict.get("eq"), + ) + transforms.append( + Enabled( + params["enabled"], + names=set(params["names"]) if params.get("names") else None, + keys=set(params["keys"]) if params.get("keys") else None, + version=version, + tags=set(params["tags"]) if params.get("tags") else None, + components=( + set(params["components"]) if params.get("components") else None + ), + match_all=params.get("match_all", False), + ) + ) + return transforms + + async def _get_session_transforms(self) -> list[Enabled]: + """Get session-specific Enabled transforms from state store.""" + try: + # Will raise RuntimeError if no session available + _ = self.session_id + except RuntimeError: + return [] + + rules = await self._get_visibility_rules() + return self._create_enabled_transforms(rules) + + async def enable_components( + self, + *, + names: set[str] | None = None, + keys: set[str] | None = None, + version: VersionSpec | None = None, + tags: set[str] | None = None, + components: set[Literal["tool", "resource", "template", "prompt"]] + | None = None, + match_all: bool = False, + ) -> None: + """Enable components matching criteria for this session only. + + Session rules override global transforms. Rules accumulate - each call + adds a new rule to the session. Later marks override earlier ones + (Enabled transform semantics). + + Sends notifications to this session only: ToolListChangedNotification, + ResourceListChangedNotification, and PromptListChangedNotification. + + Args: + names: Component names or URIs to match. + keys: Component keys to match (e.g., {"tool:my_tool@v1"}). + version: Component version spec to match. + tags: Tags to match (component must have at least one). + components: Component types to match (e.g., {"tool", "prompt"}). + match_all: If True, matches all components regardless of other criteria. + """ + # Normalize empty sets to None (empty = match all) + components = components if components else None + + # Load current rules + rules = await self._get_visibility_rules() + + # Create new rule dict + rule: dict[str, Any] = { + "enabled": True, + "names": list(names) if names else None, + "keys": list(keys) if keys else None, + "version": ( + {"gte": version.gte, "lt": version.lt, "eq": version.eq} + if version + else None + ), + "tags": list(tags) if tags else None, + "components": list(components) if components else None, + "match_all": match_all, + } + + # Add and save (notifications sent by _save_visibility_rules) + rules.append(rule) + await self._save_visibility_rules(rules, components=components) + + async def disable_components( + self, + *, + names: set[str] | None = None, + keys: set[str] | None = None, + version: VersionSpec | None = None, + tags: set[str] | None = None, + components: set[Literal["tool", "resource", "template", "prompt"]] + | None = None, + match_all: bool = False, + ) -> None: + """Disable components matching criteria for this session only. + + Session rules override global transforms. Rules accumulate - each call + adds a new rule to the session. Later marks override earlier ones + (Enabled transform semantics). + + Sends notifications to this session only: ToolListChangedNotification, + ResourceListChangedNotification, and PromptListChangedNotification. + + Args: + names: Component names or URIs to match. + keys: Component keys to match (e.g., {"tool:my_tool@v1"}). + version: Component version spec to match. + tags: Tags to match (component must have at least one). + components: Component types to match (e.g., {"tool", "prompt"}). + match_all: If True, matches all components regardless of other criteria. + """ + # Normalize empty sets to None (empty = match all) + components = components if components else None + + # Load current rules + rules = await self._get_visibility_rules() + + # Create new rule dict + rule: dict[str, Any] = { + "enabled": False, + "names": list(names) if names else None, + "keys": list(keys) if keys else None, + "version": ( + {"gte": version.gte, "lt": version.lt, "eq": version.eq} + if version + else None + ), + "tags": list(tags) if tags else None, + "components": list(components) if components else None, + "match_all": match_all, + } + + # Add and save (notifications sent by _save_visibility_rules) + rules.append(rule) + await self._save_visibility_rules(rules, components=components) + + async def reset_components(self) -> None: + """Clear all session visibility rules. + + Use this to reset session visibility back to global defaults. + + Sends notifications to this session only: ToolListChangedNotification, + ResourceListChangedNotification, and PromptListChangedNotification. + """ + await self._save_visibility_rules([]) + async def _log_to_server_and_client( data: LogData, @@ -1269,3 +1460,36 @@ def _extract_tool_calls( elif isinstance(content, ToolUseContent): return [content] return [] + + +ComponentT = TypeVar("ComponentT", bound="FastMCPComponent") + + +async def apply_session_transforms( + components: Sequence[ComponentT], +) -> Sequence[ComponentT]: + """Apply session-specific visibility transforms to components. + + This helper applies session-level enable/disable rules by marking + components with their enabled state. Session transforms override + global transforms due to mark-based semantics (later marks win). + + Args: + components: The components to apply session transforms to. + + Returns: + The components with session transforms applied. + """ + current_ctx = _current_context.get() + if current_ctx is None: + return components + + session_transforms = await current_ctx._get_session_transforms() + if not session_transforms: + return components + + # Apply each transform's marking to each component + result = list(components) + for transform in session_transforms: + result = [transform._mark_component(c) for c in result] + return result diff --git a/src/fastmcp/server/providers/base.py b/src/fastmcp/server/providers/base.py index 8da2512f84..e6dbc70141 100644 --- a/src/fastmcp/server/providers/base.py +++ b/src/fastmcp/server/providers/base.py @@ -545,7 +545,7 @@ def enable( keys: set[str] | None = None, version: VersionSpec | None = None, tags: set[str] | None = None, - components: list[Literal["tool", "resource", "template", "prompt"]] + components: set[Literal["tool", "resource", "template", "prompt"]] | None = None, only: bool = False, ) -> Self: @@ -564,7 +564,7 @@ def enable( version: Component version spec to enable (e.g., VersionSpec(eq="v1") or VersionSpec(gte="v2")). Unversioned components will not match. tags: Enable components with these tags. - components: Component types to include (e.g., ["tool", "prompt"]). + components: Component types to include (e.g., {"tool", "prompt"}). only: If True, ONLY enable matching components (allowlist mode). Returns: @@ -580,8 +580,8 @@ def enable( names=names, keys=keys, version=version, - components=frozenset(components) if components else None, - tags=frozenset(tags) if tags else None, + components=set(components) if components else None, + tags=set(tags) if tags else None, ) ) @@ -594,7 +594,7 @@ def disable( keys: set[str] | None = None, version: VersionSpec | None = None, tags: set[str] | None = None, - components: list[Literal["tool", "resource", "template", "prompt"]] + components: set[Literal["tool", "resource", "template", "prompt"]] | None = None, ) -> Self: """Disable components matching all specified criteria. @@ -609,7 +609,7 @@ def disable( version: Component version spec to disable (e.g., VersionSpec(eq="v1") or VersionSpec(gte="v2")). Unversioned components will not match. tags: Disable components with these tags. - components: Component types to include (e.g., ["tool", "prompt"]). + components: Component types to include (e.g., {"tool", "prompt"}). Returns: Self for method chaining. @@ -620,8 +620,8 @@ def disable( names=names, keys=keys, version=version, - components=frozenset(components) if components else None, - tags=frozenset(tags) if tags else None, + components=set(components) if components else None, + tags=set(tags) if tags else None, ) ) return self diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 012da6bf99..a2e500198f 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -1026,7 +1026,13 @@ async def list_tools(self, *, run_middleware: bool = True) -> Sequence[Tool]: call_next=lambda context: self.list_tools(run_middleware=False), ) - tools = [t for t in await super().list_tools() if is_enabled(t)] + # Get all tools, apply session transforms, then filter enabled + from fastmcp.server.context import apply_session_transforms + + tools = list(await super().list_tools()) + tools = await apply_session_transforms(tools) + tools = [t for t in tools if is_enabled(t)] + skip_auth, token = _get_auth_context() authorized: list[Tool] = [] for tool in tools: @@ -1088,9 +1094,16 @@ async def get_tool( The tool if found and enabled, None otherwise. """ tool = await super().get_tool(name, version) - if tool is None or not is_enabled(tool): + if tool is None: return None - return tool + + # Apply session transforms to single item + from fastmcp.server.context import apply_session_transforms + + tools = await apply_session_transforms([tool]) + if not tools or not is_enabled(tools[0]): + return None + return tools[0] async def list_resources( self, *, run_middleware: bool = True @@ -1115,7 +1128,13 @@ async def list_resources( call_next=lambda context: self.list_resources(run_middleware=False), ) - resources = [r for r in await super().list_resources() if is_enabled(r)] + # Get all resources, apply session transforms, then filter enabled + from fastmcp.server.context import apply_session_transforms + + resources = list(await super().list_resources()) + resources = await apply_session_transforms(resources) + resources = [r for r in resources if is_enabled(r)] + skip_auth, token = _get_auth_context() authorized: list[Resource] = [] for resource in resources: @@ -1176,9 +1195,16 @@ async def get_resource( The resource if found and enabled, None otherwise. """ resource = await super().get_resource(uri, version) - if resource is None or not is_enabled(resource): + if resource is None: return None - return resource + + # Apply session transforms to single item + from fastmcp.server.context import apply_session_transforms + + resources = await apply_session_transforms([resource]) + if not resources or not is_enabled(resources[0]): + return None + return resources[0] async def list_resource_templates( self, *, run_middleware: bool = True @@ -1205,9 +1231,13 @@ async def list_resource_templates( ), ) - templates = [ - t for t in await super().list_resource_templates() if is_enabled(t) - ] + # Get all templates, apply session transforms, then filter enabled + from fastmcp.server.context import apply_session_transforms + + templates = list(await super().list_resource_templates()) + templates = await apply_session_transforms(templates) + templates = [t for t in templates if is_enabled(t)] + skip_auth, token = _get_auth_context() authorized: list[ResourceTemplate] = [] for template in templates: @@ -1268,9 +1298,16 @@ async def get_resource_template( The template if found and enabled, None otherwise. """ template = await super().get_resource_template(uri, version) - if template is None or not is_enabled(template): + if template is None: return None - return template + + # Apply session transforms to single item + from fastmcp.server.context import apply_session_transforms + + templates = await apply_session_transforms([template]) + if not templates or not is_enabled(templates[0]): + return None + return templates[0] async def list_prompts(self, *, run_middleware: bool = True) -> Sequence[Prompt]: """List all enabled prompts from providers. @@ -1293,7 +1330,13 @@ async def list_prompts(self, *, run_middleware: bool = True) -> Sequence[Prompt] call_next=lambda context: self.list_prompts(run_middleware=False), ) - prompts = [p for p in await super().list_prompts() if is_enabled(p)] + # Get all prompts, apply session transforms, then filter enabled + from fastmcp.server.context import apply_session_transforms + + prompts = list(await super().list_prompts()) + prompts = await apply_session_transforms(prompts) + prompts = [p for p in prompts if is_enabled(p)] + skip_auth, token = _get_auth_context() authorized: list[Prompt] = [] for prompt in prompts: @@ -1354,9 +1397,16 @@ async def get_prompt( The prompt if found and enabled, None otherwise. """ prompt = await super().get_prompt(name, version) - if prompt is None or not is_enabled(prompt): + if prompt is None: return None - return prompt + + # Apply session transforms to single item + from fastmcp.server.context import apply_session_transforms + + prompts = await apply_session_transforms([prompt]) + if not prompts or not is_enabled(prompts[0]): + return None + return prompts[0] @overload async def call_tool( diff --git a/src/fastmcp/server/transforms/enabled.py b/src/fastmcp/server/transforms/enabled.py index 98614bc09a..cd9f023e65 100644 --- a/src/fastmcp/server/transforms/enabled.py +++ b/src/fastmcp/server/transforms/enabled.py @@ -8,7 +8,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar from fastmcp.resources.resource import Resource from fastmcp.resources.template import ResourceTemplate @@ -47,14 +47,14 @@ class Enabled(Transform): Example: ```python # Disable components tagged "internal" - Enabled(False, tags=frozenset({"internal"})) + Enabled(False, tags={"internal"}) # Re-enable specific tool (override earlier disable) Enabled(True, names={"safe_tool"}) # Allowlist via composition: Enabled(False, match_all=True) # disable everything - Enabled(True, tags=frozenset({"public"})) # enable public + Enabled(True, tags={"public"}) # enable public ``` """ @@ -65,8 +65,9 @@ def __init__( names: set[str] | None = None, keys: set[str] | None = None, version: VersionSpec | None = None, - tags: frozenset[str] | None = None, - components: frozenset[str] | None = None, + tags: set[str] | None = None, + components: set[Literal["tool", "resource", "template", "prompt"]] + | None = None, match_all: bool = False, ) -> None: """Initialize an enabled marker. @@ -78,15 +79,15 @@ def __init__( version: Component version spec to match. Unversioned components (version=None) will NOT match a version spec. tags: Tags to match (component must have at least one). - components: Component types to match (e.g., frozenset({"tool", "prompt"})). + components: Component types to match (e.g., {"tool", "prompt"}). match_all: If True, matches all components regardless of other criteria. """ self._enabled = enabled self.names = names self.keys = keys self.version = version - self.tags = tags # e.g., frozenset({"internal", "deprecated"}) - self.components = components # e.g., frozenset({"tool", "prompt"}) + self.tags = tags # e.g., {"internal", "deprecated"} + self.components = components # e.g., {"tool", "prompt"} self.match_all = match_all def __repr__(self) -> str: diff --git a/tests/contrib/test_component_manager.py b/tests/contrib/test_component_manager.py index c2443110a7..8acae9aecc 100644 --- a/tests/contrib/test_component_manager.py +++ b/tests/contrib/test_component_manager.py @@ -50,7 +50,7 @@ def client(self, mcp): async def test_enable_tool_route(self, client, mcp): """Test enabling a tool via the HTTP route.""" # First disable the tool - mcp.disable(names={"test_tool"}, components=["tool"]) + mcp.disable(names={"test_tool"}, components={"tool"}) tools = await mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) @@ -83,7 +83,7 @@ async def test_disable_tool_route(self, client, mcp): async def test_enable_resource_route(self, client, mcp): """Test enabling a resource via the HTTP route.""" # First disable the resource (can use URI as name for resources) - mcp.disable(names={"data://test_resource"}, components=["resource"]) + mcp.disable(names={"data://test_resource"}, components={"resource"}) resources = await mcp.list_resources() assert not any(str(r.uri) == "data://test_resource" for r in resources) @@ -116,7 +116,7 @@ async def test_disable_resource_route(self, client, mcp): async def test_enable_template_route(self, client, mcp): """Test enabling a resource template via the HTTP route.""" key = "data://test_resource/{id}" - mcp.disable(names={"data://test_resource/{id}"}, components=["template"]) + mcp.disable(names={"data://test_resource/{id}"}, components={"template"}) templates = await mcp.list_resource_templates() assert not any(t.uri_template == key for t in templates) response = client.post("/resources/data://test_resource/{id}/enable") @@ -143,7 +143,7 @@ async def test_disable_template_route(self, client, mcp): async def test_enable_prompt_route(self, client, mcp): """Test enabling a prompt via the HTTP route.""" # First disable the prompt - mcp.disable(names={"test_prompt"}, components=["prompt"]) + mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) @@ -224,7 +224,7 @@ def test_prompt() -> str: async def test_unauthorized_enable_tool(self): """Test that unauthenticated requests to enable a tool are rejected.""" - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) @@ -235,7 +235,7 @@ async def test_unauthorized_enable_tool(self): async def test_authorized_enable_tool(self): """Test that authenticated requests to enable a tool are allowed.""" - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) @@ -273,7 +273,7 @@ async def test_authorized_disable_tool(self): async def test_forbidden_enable_tool(self): """Test that requests with insufficient scopes are rejected.""" - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) @@ -287,7 +287,7 @@ async def test_forbidden_enable_tool(self): async def test_authorized_enable_resource(self): """Test that authenticated requests to enable a resource are allowed.""" - self.mcp.disable(names={"data://test_resource"}, components=["resource"]) + self.mcp.disable(names={"data://test_resource"}, components={"resource"}) resources = await self.mcp.list_resources() assert not any(str(r.uri) == "data://test_resource" for r in resources) @@ -312,7 +312,7 @@ async def test_unauthorized_disable_resource(self): async def test_forbidden_enable_resource(self): """Test that requests with insufficient scopes are rejected.""" - self.mcp.disable(names={"data://test_resource"}, components=["resource"]) + self.mcp.disable(names={"data://test_resource"}, components={"resource"}) resources = await self.mcp.list_resources() assert not any(str(r.uri) == "data://test_resource" for r in resources) @@ -340,7 +340,7 @@ async def test_authorized_disable_resource(self): async def test_unauthorized_enable_prompt(self): """Test that unauthenticated requests to enable a prompt are rejected.""" - self.mcp.disable(names={"test_prompt"}, components=["prompt"]) + self.mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await self.mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) @@ -351,7 +351,7 @@ async def test_unauthorized_enable_prompt(self): async def test_authorized_enable_prompt(self): """Test that authenticated requests to enable a prompt are allowed.""" - self.mcp.disable(names={"test_prompt"}, components=["prompt"]) + self.mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await self.mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) @@ -429,7 +429,7 @@ def client_with_path(self, mcp_with_path): return TestClient(mcp_with_path.http_app()) async def test_enable_tool_route_with_path(self, client_with_path, mcp_with_path): - mcp_with_path.disable(names={"test_tool"}, components=["tool"]) + mcp_with_path.disable(names={"test_tool"}, components={"tool"}) tools = await mcp_with_path.list_tools() assert not any(t.name == "test_tool" for t in tools) response = client_with_path.post("/test/tools/test_tool/enable") @@ -450,7 +450,7 @@ async def test_disable_resource_route_with_path( assert not any(str(r.uri) == "data://test_resource" for r in resources) async def test_enable_prompt_route_with_path(self, client_with_path, mcp_with_path): - mcp_with_path.disable(names={"test_prompt"}, components=["prompt"]) + mcp_with_path.disable(names={"test_prompt"}, components={"prompt"}) prompts = await mcp_with_path.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) response = client_with_path.post("/test/prompts/test_prompt/enable") @@ -503,7 +503,7 @@ def test_prompt() -> str: self.client = TestClient(self.mcp.http_app()) async def test_unauthorized_enable_tool(self): - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) response = self.client.post("/test/tools/test_tool/enable") @@ -512,7 +512,7 @@ async def test_unauthorized_enable_tool(self): assert not any(t.name == "test_tool" for t in tools) async def test_forbidden_enable_tool(self): - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) response = self.client.post( @@ -524,7 +524,7 @@ async def test_forbidden_enable_tool(self): assert not any(t.name == "test_tool" for t in tools) async def test_authorized_enable_tool(self): - self.mcp.disable(names={"test_tool"}, components=["tool"]) + self.mcp.disable(names={"test_tool"}, components={"tool"}) tools = await self.mcp.list_tools() assert not any(t.name == "test_tool" for t in tools) response = self.client.post( @@ -568,7 +568,7 @@ async def test_authorized_disable_resource(self): assert not any(str(r.uri) == "data://test_resource" for r in resources) async def test_unauthorized_enable_prompt(self): - self.mcp.disable(names={"test_prompt"}, components=["prompt"]) + self.mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await self.mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) response = self.client.post("/test/prompts/test_prompt/enable") @@ -577,7 +577,7 @@ async def test_unauthorized_enable_prompt(self): assert not any(p.name == "test_prompt" for p in prompts) async def test_forbidden_enable_prompt(self): - self.mcp.disable(names={"test_prompt"}, components=["prompt"]) + self.mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await self.mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) response = self.client.post( @@ -589,7 +589,7 @@ async def test_forbidden_enable_prompt(self): assert not any(p.name == "test_prompt" for p in prompts) async def test_authorized_enable_prompt(self): - self.mcp.disable(names={"test_prompt"}, components=["prompt"]) + self.mcp.disable(names={"test_prompt"}, components={"prompt"}) prompts = await self.mcp.list_prompts() assert not any(p.name == "test_prompt" for p in prompts) response = self.client.post( diff --git a/tests/deprecated/server/test_include_exclude_tags.py b/tests/deprecated/server/test_include_exclude_tags.py index c64f400f7b..223f70048a 100644 --- a/tests/deprecated/server/test_include_exclude_tags.py +++ b/tests/deprecated/server/test_include_exclude_tags.py @@ -29,7 +29,7 @@ def test_exclude_tags_still_works(self): assert len(enabled_transforms) == 1 e = enabled_transforms[0] assert e._enabled is False - assert e.tags == frozenset({"internal"}) + assert e.tags == {"internal"} def test_include_tags_still_works(self): """include_tags adds Enabled transforms for allowlist mode.""" @@ -49,7 +49,7 @@ def test_include_tags_still_works(self): # Second should enable matching tags enable_transform = enabled_transforms[1] assert enable_transform._enabled is True - assert enable_transform.tags == frozenset({"public"}) + assert enable_transform.tags == {"public"} def test_exclude_and_include_both_create_transforms(self): """exclude_tags and include_tags both create transforms.""" @@ -63,6 +63,6 @@ def test_exclude_and_include_both_create_transforms(self): assert len(enabled_transforms) == 3 # Check we have both tag rules - tags_in_transforms = {frozenset(t.tags) for t in enabled_transforms if t.tags} - assert frozenset({"public"}) in tags_in_transforms - assert frozenset({"deprecated"}) in tags_in_transforms + tags_in_transforms = [t.tags for t in enabled_transforms if t.tags] + assert {"public"} in tags_in_transforms + assert {"deprecated"} in tags_in_transforms diff --git a/tests/server/middleware/test_middleware.py b/tests/server/middleware/test_middleware.py index 096080bc09..fdf523a22a 100644 --- a/tests/server/middleware/test_middleware.py +++ b/tests/server/middleware/test_middleware.py @@ -926,7 +926,7 @@ async def test_call_tool( self, mcp_server: FastMCP, recording_middleware: RecordingMiddleware ): # proxy server will have its tools listed as well as called in order to - # run the `should_enable_component` hook prior to the call. + # apply transforms and filters prior to the call. proxy_server = FastMCP.as_proxy(mcp_server, name="Proxy Server") async with Client(proxy_server) as client: await client.call_tool("add", {"a": 1, "b": 2}) diff --git a/tests/server/providers/test_local_provider_prompts.py b/tests/server/providers/test_local_provider_prompts.py index 234534f802..6b0400e647 100644 --- a/tests/server/providers/test_local_provider_prompts.py +++ b/tests/server/providers/test_local_provider_prompts.py @@ -330,12 +330,12 @@ def sample_prompt() -> str: prompts = await mcp.list_prompts() assert any(p.name == "sample_prompt" for p in prompts) - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert not any(p.name == "sample_prompt" for p in prompts) - mcp.enable(names={"sample_prompt"}, components=["prompt"]) + mcp.enable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert any(p.name == "sample_prompt" for p in prompts) @@ -347,7 +347,7 @@ async def test_prompt_disabled(self): def sample_prompt() -> str: return "Hello, world!" - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert len(prompts) == 0 @@ -358,11 +358,11 @@ async def test_prompt_toggle_enabled(self): def sample_prompt() -> str: return "Hello, world!" - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert not any(p.name == "sample_prompt" for p in prompts) - mcp.enable(names={"sample_prompt"}, components=["prompt"]) + mcp.enable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert len(prompts) == 1 @@ -373,7 +373,7 @@ async def test_prompt_toggle_disabled(self): def sample_prompt() -> str: return "Hello, world!" - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert len(prompts) == 0 @@ -391,7 +391,7 @@ def sample_prompt() -> str: prompt = await mcp.get_prompt("sample_prompt") assert prompt is not None - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) prompts = await mcp.list_prompts() assert len(prompts) == 0 @@ -406,7 +406,7 @@ async def test_cant_get_disabled_prompt(self): def sample_prompt() -> str: return "Hello, world!" - mcp.disable(names={"sample_prompt"}, components=["prompt"]) + mcp.disable(names={"sample_prompt"}, components={"prompt"}) # get_prompt() applies enabled transform, returns None for disabled prompt = await mcp.get_prompt("sample_prompt") diff --git a/tests/server/providers/test_local_provider_resources.py b/tests/server/providers/test_local_provider_resources.py index b18ffc10f4..972b9358b0 100644 --- a/tests/server/providers/test_local_provider_resources.py +++ b/tests/server/providers/test_local_provider_resources.py @@ -738,12 +738,12 @@ def sample_resource() -> str: resources = await mcp.list_resources() assert any(str(r.uri) == "resource://data" for r in resources) - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert not any(str(r.uri) == "resource://data" for r in resources) - mcp.enable(names={"resource://data"}, components=["resource"]) + mcp.enable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert any(str(r.uri) == "resource://data" for r in resources) @@ -755,7 +755,7 @@ async def test_resource_disabled(self): def sample_resource() -> str: return "Hello, world!" - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert len(resources) == 0 @@ -769,11 +769,11 @@ async def test_resource_toggle_enabled(self): def sample_resource() -> str: return "Hello, world!" - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert not any(str(r.uri) == "resource://data" for r in resources) - mcp.enable(names={"resource://data"}, components=["resource"]) + mcp.enable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert len(resources) == 1 @@ -784,7 +784,7 @@ async def test_resource_toggle_disabled(self): def sample_resource() -> str: return "Hello, world!" - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert len(resources) == 0 @@ -801,7 +801,7 @@ def sample_resource() -> str: resource = await mcp.get_resource("resource://data") assert resource is not None - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) resources = await mcp.list_resources() assert len(resources) == 0 @@ -815,7 +815,7 @@ async def test_cant_read_disabled_resource(self): def sample_resource() -> str: return "Hello, world!" - mcp.disable(names={"resource://data"}, components=["resource"]) + mcp.disable(names={"resource://data"}, components={"resource"}) with pytest.raises(NotFoundError, match="Unknown resource"): await mcp.read_resource("resource://data") @@ -891,12 +891,12 @@ def sample_template(param: str) -> str: templates = await mcp.list_resource_templates() assert any(t.uri_template == "resource://{param}" for t in templates) - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert not any(t.uri_template == "resource://{param}" for t in templates) - mcp.enable(names={"resource://{param}"}, components=["template"]) + mcp.enable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert any(t.uri_template == "resource://{param}" for t in templates) @@ -908,7 +908,7 @@ async def test_template_disabled(self): def sample_template(param: str) -> str: return f"Template: {param}" - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert len(templates) == 0 @@ -922,11 +922,11 @@ async def test_template_toggle_enabled(self): def sample_template(param: str) -> str: return f"Template: {param}" - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert not any(t.uri_template == "resource://{param}" for t in templates) - mcp.enable(names={"resource://{param}"}, components=["template"]) + mcp.enable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert len(templates) == 1 @@ -937,7 +937,7 @@ async def test_template_toggle_disabled(self): def sample_template(param: str) -> str: return f"Template: {param}" - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert len(templates) == 0 @@ -954,7 +954,7 @@ def sample_template(param: str) -> str: template = await mcp.get_resource_template("resource://{param}") assert template is not None - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) templates = await mcp.list_resource_templates() assert len(templates) == 0 @@ -968,7 +968,7 @@ async def test_cant_read_disabled_template(self): def sample_template(param: str) -> str: return f"Template: {param}" - mcp.disable(names={"resource://{param}"}, components=["template"]) + mcp.disable(names={"resource://{param}"}, components={"template"}) with pytest.raises(NotFoundError, match="Unknown resource"): await mcp.read_resource("resource://test") diff --git a/tests/server/providers/test_local_provider_tools.py b/tests/server/providers/test_local_provider_tools.py index d5a725bae8..fb000519be 100644 --- a/tests/server/providers/test_local_provider_tools.py +++ b/tests/server/providers/test_local_provider_tools.py @@ -1473,14 +1473,14 @@ def sample_tool(x: int) -> int: assert any(t.name == "sample_tool" for t in tools) # Disable via server - mcp.disable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) # Tool should not be in list when disabled tools = await mcp.list_tools() assert not any(t.name == "sample_tool" for t in tools) # Re-enable via server - mcp.enable(names={"sample_tool"}, components=["tool"]) + mcp.enable(names={"sample_tool"}, components={"tool"}) tools = await mcp.list_tools() assert any(t.name == "sample_tool" for t in tools) @@ -1491,7 +1491,7 @@ async def test_tool_disabled_via_server(self): def sample_tool(x: int) -> int: return x * 2 - mcp.disable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) tools = await mcp.list_tools() assert len(tools) == 0 @@ -1505,8 +1505,8 @@ async def test_tool_toggle_enabled(self): def sample_tool(x: int) -> int: return x * 2 - mcp.disable(names={"sample_tool"}, components=["tool"]) - mcp.enable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) + mcp.enable(names={"sample_tool"}, components={"tool"}) tools = await mcp.list_tools() assert len(tools) == 1 @@ -1517,7 +1517,7 @@ async def test_tool_toggle_disabled(self): def sample_tool(x: int) -> int: return x * 2 - mcp.disable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) tools = await mcp.list_tools() assert len(tools) == 0 @@ -1534,7 +1534,7 @@ def sample_tool(x: int) -> int: tool = await mcp.get_tool("sample_tool") assert tool is not None - mcp.disable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) tools = await mcp.list_tools() assert len(tools) == 0 @@ -1548,7 +1548,7 @@ async def test_cant_call_disabled_tool(self): def sample_tool(x: int) -> int: return x * 2 - mcp.disable(names={"sample_tool"}, components=["tool"]) + mcp.disable(names={"sample_tool"}, components={"tool"}) with pytest.raises(NotFoundError, match="Unknown tool"): await mcp.call_tool("sample_tool", {"x": 5}) diff --git a/tests/server/test_mount.py b/tests/server/test_mount.py index fea71bf2e6..2dae73a23a 100644 --- a/tests/server/test_mount.py +++ b/tests/server/test_mount.py @@ -1533,12 +1533,12 @@ def my_tool() -> str: assert any(t.name == "my_tool" for t in tools) # Disable and re-enable - main_app.disable(names={"my_tool"}, components=["tool"]) + main_app.disable(names={"my_tool"}, components={"tool"}) # Verify tool is now disabled tools = await main_app.list_tools() assert not any(t.name == "my_tool" for t in tools) - main_app.enable(names={"my_tool"}, components=["tool"]) + main_app.enable(names={"my_tool"}, components={"tool"}) # Verify tool is now enabled tools = await main_app.list_tools() assert any(t.name == "my_tool" for t in tools) @@ -1556,12 +1556,12 @@ def my_resource() -> str: main_app.mount(sub_app) # Disable and re-enable - main_app.disable(names={"data://test"}, components=["resource"]) + main_app.disable(names={"data://test"}, components={"resource"}) # Verify resource is now disabled resources = await main_app.list_resources() assert not any(str(r.uri) == "data://test" for r in resources) - main_app.enable(names={"data://test"}, components=["resource"]) + main_app.enable(names={"data://test"}, components={"resource"}) # Verify resource is now enabled resources = await main_app.list_resources() assert any(str(r.uri) == "data://test" for r in resources) @@ -1579,12 +1579,12 @@ def my_prompt() -> str: main_app.mount(sub_app) # Disable and re-enable - main_app.disable(names={"my_prompt"}, components=["prompt"]) + main_app.disable(names={"my_prompt"}, components={"prompt"}) # Verify prompt is now disabled prompts = await main_app.list_prompts() assert not any(p.name == "my_prompt" for p in prompts) - main_app.enable(names={"my_prompt"}, components=["prompt"]) + main_app.enable(names={"my_prompt"}, components={"prompt"}) # Verify prompt is now enabled prompts = await main_app.list_prompts() assert any(p.name == "my_prompt" for p in prompts) diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 045e6b2026..12166715c5 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -212,7 +212,7 @@ def dummy_tool() -> str: middleware=(), # Empty tuple tools=(Tool.from_function(dummy_tool),), # Tuple of tools include_tags={"tag1", "tag2"}, # Set - exclude_tags=frozenset({"tag3"}), # Frozen set + exclude_tags={"tag3"}, # Set ) assert mcp is not None assert mcp.name == "test" diff --git a/tests/server/test_session_visibility.py b/tests/server/test_session_visibility.py new file mode 100644 index 0000000000..ca072c3f67 --- /dev/null +++ b/tests/server/test_session_visibility.py @@ -0,0 +1,620 @@ +"""Tests for session-specific visibility control via Context.""" + +from dataclasses import dataclass, field +from datetime import datetime + +import anyio +import mcp.types + +from fastmcp.client.messages import MessageHandler +from fastmcp.server.context import Context +from fastmcp.server.server import FastMCP + + +@dataclass +class NotificationRecording: + """Record of a notification that was received.""" + + method: str + notification: mcp.types.ServerNotification + timestamp: datetime = field(default_factory=datetime.now) + + +class RecordingMessageHandler(MessageHandler): + """A message handler that records all notifications.""" + + def __init__(self): + super().__init__() + self.notifications: list[NotificationRecording] = [] + + async def on_notification(self, message: mcp.types.ServerNotification) -> None: + """Record all notifications with timestamp.""" + self.notifications.append( + NotificationRecording(method=message.root.method, notification=message) + ) + + def get_notifications( + self, method: str | None = None + ) -> list[NotificationRecording]: + """Get all recorded notifications, optionally filtered by method.""" + if method is None: + return self.notifications + return [n for n in self.notifications if n.method == method] + + def reset(self): + """Clear all recorded notifications.""" + self.notifications.clear() + + +class TestSessionVisibility: + """Test session-specific visibility control via Context.""" + + async def test_enable_components_stores_rule_dict(self): + """Test that enable_components stores a rule dict in session state.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + # Check that the rule was stored + rules = await ctx._get_visibility_rules() + assert len(rules) == 1 + assert rules[0]["enabled"] is True + assert rules[0]["tags"] == ["finance"] + return "activated" + + async with Client(mcp) as client: + result = await client.call_tool("activate_finance", {}) + assert result.data == "activated" + + async def test_disable_components_stores_rule_dict(self): + """Test that disable_components stores a rule dict in session state.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"internal"}) + def internal_tool() -> str: + return "internal" + + @mcp.tool + async def deactivate_internal(ctx: Context) -> str: + await ctx.disable_components(tags={"internal"}) + # Check that the rule was stored + rules = await ctx._get_visibility_rules() + assert len(rules) == 1 + assert rules[0]["enabled"] is False + assert rules[0]["tags"] == ["internal"] + return "deactivated" + + async with Client(mcp) as client: + result = await client.call_tool("deactivate_internal", {}) + assert result.data == "deactivated" + + async def test_session_rules_override_global_disables(self): + """Test that session enable rules override global disable transforms.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + # Globally disable finance tools + mcp.disable(tags={"finance"}) + + async with Client(mcp) as client: + # Before activation, finance tool should not be visible + tools_before = await client.list_tools() + assert not any(t.name == "finance_tool" for t in tools_before) + + # Activate finance for this session + await client.call_tool("activate_finance", {}) + + # After activation, finance tool should be visible in this session + tools_after = await client.list_tools() + assert any(t.name == "finance_tool" for t in tools_after) + + async def test_rules_persist_across_requests(self): + """Test that session rules persist across multiple requests.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + @mcp.tool + async def check_rules(ctx: Context) -> int: + rules = await ctx._get_visibility_rules() + return len(rules) + + # Globally disable finance tools + mcp.disable(tags={"finance"}) + + async with Client(mcp) as client: + # Activate finance + await client.call_tool("activate_finance", {}) + + # In a subsequent request, rules should still be there + result = await client.call_tool("check_rules", {}) + assert result.data == 1 + + # And finance tool should still be visible + tools = await client.list_tools() + assert any(t.name == "finance_tool" for t in tools) + + async def test_rules_isolated_between_sessions(self): + """Test that session rules are isolated between different sessions.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + # Globally disable finance tools + mcp.disable(tags={"finance"}) + + # Session A activates finance + async with Client(mcp) as client_a: + await client_a.call_tool("activate_finance", {}) + tools_a = await client_a.list_tools() + assert any(t.name == "finance_tool" for t in tools_a) + + # Session B should not see finance tool (different session) + async with Client(mcp) as client_b: + tools_b = await client_b.list_tools() + assert not any(t.name == "finance_tool" for t in tools_b) + + async def test_version_spec_serialization(self): + """Test that VersionSpec is serialized/deserialized correctly.""" + from fastmcp import Client + from fastmcp.utilities.versions import VersionSpec + + mcp = FastMCP("test") + + @mcp.tool(version="1.0.0") + def old_tool() -> str: + return "old" + + @mcp.tool(version="2.0.0") + def new_tool() -> str: + return "new" + + @mcp.tool + async def enable_v2_only(ctx: Context) -> str: + await ctx.enable_components(version=VersionSpec(gte="2.0.0")) + # Check serialization - version is stored as a dict + rules = await ctx._get_visibility_rules() + assert rules[0]["version"]["gte"] == "2.0.0" + assert rules[0]["version"]["lt"] is None + assert rules[0]["version"]["eq"] is None + return "enabled" + + # Globally disable all versioned tools + mcp.disable(names={"old_tool", "new_tool"}) + + async with Client(mcp) as client: + # Enable v2 tools + await client.call_tool("enable_v2_only", {}) + + # Should see new_tool (v2.0.0) but not old_tool (v1.0.0) + tools = await client.list_tools() + assert any(t.name == "new_tool" for t in tools) + assert not any(t.name == "old_tool" for t in tools) + + async def test_clear_visibility_rules(self): + """Test that reset_components removes all session rules.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + @mcp.tool + async def clear_rules(ctx: Context) -> str: + await ctx.reset_components() + rules = await ctx._get_visibility_rules() + assert len(rules) == 0 + return "cleared" + + # Globally disable finance tools + mcp.disable(tags={"finance"}) + + async with Client(mcp) as client: + # Activate finance + await client.call_tool("activate_finance", {}) + tools_after_activate = await client.list_tools() + assert any(t.name == "finance_tool" for t in tools_after_activate) + + # Clear rules + await client.call_tool("clear_rules", {}) + + # Finance tool should no longer be visible (back to global disable) + tools_after_clear = await client.list_tools() + assert not any(t.name == "finance_tool" for t in tools_after_clear) + + async def test_multiple_rules_accumulate(self): + """Test that multiple enable/disable calls accumulate rules.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool(tags={"admin"}) + def admin_tool() -> str: + return "admin" + + @mcp.tool + async def activate_multiple(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + await ctx.enable_components(tags={"admin"}) + rules = await ctx._get_visibility_rules() + assert len(rules) == 2 + return "activated" + + # Globally disable finance and admin tools + mcp.disable(tags={"finance", "admin"}) + + async with Client(mcp) as client: + # Activate both + await client.call_tool("activate_multiple", {}) + + # Both should be visible + tools = await client.list_tools() + assert any(t.name == "finance_tool" for t in tools) + assert any(t.name == "admin_tool" for t in tools) + + async def test_later_rules_override_earlier_rules(self): + """Test that later session rules override earlier ones (mark semantics).""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"test"}) + def test_tool() -> str: + return "test" + + @mcp.tool + async def toggle_test(ctx: Context) -> str: + # First enable, then disable + await ctx.enable_components(tags={"test"}) + await ctx.disable_components(tags={"test"}) + return "toggled" + + async with Client(mcp) as client: + # Toggle (enable then disable) + await client.call_tool("toggle_test", {}) + + # The disable should win (later mark overrides earlier) + tools = await client.list_tools() + assert not any(t.name == "test_tool" for t in tools) + + async def test_session_transforms_apply_to_resources(self): + """Test that session transforms apply to resources too.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.resource("resource://finance", tags={"finance"}) + def finance_resource() -> str: + return "finance data" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + # Globally disable finance resources + mcp.disable(tags={"finance"}) + + async with Client(mcp) as client: + # Before activation, finance resource should not be visible + resources_before = await client.list_resources() + assert not any(str(r.uri) == "resource://finance" for r in resources_before) + + # Activate finance for this session + await client.call_tool("activate_finance", {}) + + # After activation, finance resource should be visible + resources_after = await client.list_resources() + assert any(str(r.uri) == "resource://finance" for r in resources_after) + + async def test_session_transforms_apply_to_prompts(self): + """Test that session transforms apply to prompts too.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.prompt(tags={"finance"}) + def finance_prompt() -> str: + return "finance prompt" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + # Globally disable finance prompts + mcp.disable(tags={"finance"}) + + async with Client(mcp) as client: + # Before activation, finance prompt should not be visible + prompts_before = await client.list_prompts() + assert not any(p.name == "finance_prompt" for p in prompts_before) + + # Activate finance for this session + await client.call_tool("activate_finance", {}) + + # After activation, finance prompt should be visible + prompts_after = await client.list_prompts() + assert any(p.name == "finance_prompt" for p in prompts_after) + + +class TestSessionVisibilityNotifications: + """Test that notifications are sent when session visibility changes.""" + + async def test_enable_components_sends_notifications(self): + """Test that enable_components sends all three notification types.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool + async def activate(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + handler = RecordingMessageHandler() + async with Client(mcp, message_handler=handler) as client: + handler.reset() + await client.call_tool("activate", {}) + + # Should receive all three notifications + tool_notifications = handler.get_notifications( + "notifications/tools/list_changed" + ) + resource_notifications = handler.get_notifications( + "notifications/resources/list_changed" + ) + prompt_notifications = handler.get_notifications( + "notifications/prompts/list_changed" + ) + assert len(tool_notifications) == 1 + assert len(resource_notifications) == 1 + assert len(prompt_notifications) == 1 + + async def test_disable_components_sends_notifications(self): + """Test that disable_components sends all three notification types.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool + async def deactivate(ctx: Context) -> str: + await ctx.disable_components(tags={"finance"}) + return "deactivated" + + handler = RecordingMessageHandler() + async with Client(mcp, message_handler=handler) as client: + handler.reset() + await client.call_tool("deactivate", {}) + + # Should receive all three notifications + assert ( + len(handler.get_notifications("notifications/tools/list_changed")) == 1 + ) + assert ( + len(handler.get_notifications("notifications/resources/list_changed")) + == 1 + ) + assert ( + len(handler.get_notifications("notifications/prompts/list_changed")) + == 1 + ) + + async def test_clear_visibility_rules_sends_notifications(self): + """Test that reset_components sends notifications.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool + async def clear(ctx: Context) -> str: + await ctx.reset_components() + return "cleared" + + handler = RecordingMessageHandler() + async with Client(mcp, message_handler=handler) as client: + handler.reset() + await client.call_tool("clear", {}) + + # Should receive all three notifications + assert ( + len(handler.get_notifications("notifications/tools/list_changed")) == 1 + ) + assert ( + len(handler.get_notifications("notifications/resources/list_changed")) + == 1 + ) + assert ( + len(handler.get_notifications("notifications/prompts/list_changed")) + == 1 + ) + + async def test_components_hint_limits_notifications(self): + """Test that the components hint limits which notifications are sent.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool + async def activate_tools_only(ctx: Context) -> str: + # Only specify tool components - should only send tool notification + await ctx.enable_components(tags={"finance"}, components={"tool"}) + return "activated" + + handler = RecordingMessageHandler() + async with Client(mcp, message_handler=handler) as client: + handler.reset() + await client.call_tool("activate_tools_only", {}) + + # Should only receive tool notification + assert ( + len(handler.get_notifications("notifications/tools/list_changed")) == 1 + ) + assert ( + len(handler.get_notifications("notifications/resources/list_changed")) + == 0 + ) + assert ( + len(handler.get_notifications("notifications/prompts/list_changed")) + == 0 + ) + + +class TestConcurrentSessionIsolation: + """Test that concurrent sessions don't leak visibility transforms.""" + + async def test_concurrent_sessions_isolated(self): + """Test that two concurrent clients don't leak session transforms.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"finance"}) + def finance_tool() -> str: + return "finance" + + @mcp.tool + async def activate_finance(ctx: Context) -> str: + await ctx.enable_components(tags={"finance"}) + return "activated" + + # Globally disable finance tools + mcp.disable(tags={"finance"}) + + # Track what each session sees + session_a_sees_finance = False + session_b_sees_finance = False + ready_event = anyio.Event() + + async def session_a(): + nonlocal session_a_sees_finance + async with Client(mcp) as client: + # Activate finance for this session + await client.call_tool("activate_finance", {}) + + # Signal that session A has activated + ready_event.set() + + # Check that session A sees finance tool + tools = await client.list_tools() + session_a_sees_finance = any(t.name == "finance_tool" for t in tools) + + # Keep session A alive while session B checks + await anyio.sleep(0.2) + + async def session_b(): + nonlocal session_b_sees_finance + # Wait for session A to activate + await ready_event.wait() + + async with Client(mcp) as client: + # Session B should NOT see finance tool + tools = await client.list_tools() + session_b_sees_finance = any(t.name == "finance_tool" for t in tools) + + async with anyio.create_task_group() as tg: + tg.start_soon(session_a) + tg.start_soon(session_b) + + # Session A should see finance, session B should not + assert session_a_sees_finance is True, "Session A should see finance tool" + assert session_b_sees_finance is False, "Session B should NOT see finance tool" + + async def test_many_concurrent_sessions_isolated(self): + """Test that many concurrent sessions remain properly isolated.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"premium"}) + def premium_tool() -> str: + return "premium" + + @mcp.tool + async def activate_premium(ctx: Context) -> str: + await ctx.enable_components(tags={"premium"}) + return "activated" + + # Globally disable premium tools + mcp.disable(tags={"premium"}) + + results: dict[str, bool] = {} + + async def activated_session(session_id: str): + async with Client(mcp) as client: + await client.call_tool("activate_premium", {}) + tools = await client.list_tools() + results[session_id] = any(t.name == "premium_tool" for t in tools) + + async def non_activated_session(session_id: str): + async with Client(mcp) as client: + tools = await client.list_tools() + results[session_id] = any(t.name == "premium_tool" for t in tools) + + async with anyio.create_task_group() as tg: + # Start 5 activated sessions + for i in range(5): + tg.start_soon(activated_session, f"activated_{i}") + # Start 5 non-activated sessions + for i in range(5): + tg.start_soon(non_activated_session, f"non_activated_{i}") + + # All activated sessions should see premium tool + for i in range(5): + assert results[f"activated_{i}"] is True, ( + f"Activated session {i} should see premium tool" + ) + + # All non-activated sessions should NOT see premium tool + for i in range(5): + assert results[f"non_activated_{i}"] is False, ( + f"Non-activated session {i} should NOT see premium tool" + ) diff --git a/tests/server/transforms/test_enabled.py b/tests/server/transforms/test_enabled.py index 6dca8e56b9..c31083cd56 100644 --- a/tests/server/transforms/test_enabled.py +++ b/tests/server/transforms/test_enabled.py @@ -65,13 +65,13 @@ def test_unversioned_does_not_match_version_spec(self): def test_match_by_tag(self): """Matches if component has any of the specified tags.""" - t = Enabled(False, tags=frozenset({"internal", "deprecated"})) + t = Enabled(False, tags=set({"internal", "deprecated"})) assert t._matches(Tool(name="foo", parameters={}, tags={"internal"})) is True assert t._matches(Tool(name="foo", parameters={}, tags={"public"})) is False def test_match_by_component_type(self): """Only matches specified component types.""" - t = Enabled(False, names={"foo"}, components=frozenset({"prompt"})) + t = Enabled(False, names={"foo"}, components={"prompt"}) # Tool has key "tool:foo@", not "prompt:foo@" assert t._matches(Tool(name="foo", parameters={})) is False @@ -81,7 +81,7 @@ def test_all_criteria_must_match(self): False, names={"foo"}, version=VersionSpec(eq="v1"), - tags=frozenset({"internal"}), + tags=set({"internal"}), ) # All match assert ( @@ -234,7 +234,7 @@ def tools(self): async def test_list_tools_marks_matching(self, tools): """list_tools applies marks to matching components.""" - disable_internal = Enabled(False, tags=frozenset({"internal"})) + disable_internal = Enabled(False, tags=set({"internal"})) async def base(): return tools @@ -248,8 +248,8 @@ async def base(): async def test_later_transform_overrides(self, tools): """Later transforms in chain override earlier ones.""" - disable_internal = Enabled(False, tags=frozenset({"internal"})) - enable_safe = Enabled(True, tags=frozenset({"safe"})) + disable_internal = Enabled(False, tags=set({"internal"})) + enable_safe = Enabled(True, tags=set({"safe"})) async def base(): return tools @@ -268,7 +268,7 @@ async def after_disable(): async def test_allowlist_pattern(self, tools): """Disable all, then enable specific = allowlist.""" disable_all = Enabled(False, match_all=True) - enable_public = Enabled(True, tags=frozenset({"public"})) + enable_public = Enabled(True, tags=set({"public"})) async def base(): return tools diff --git a/tests/tools/test_tool_transform.py b/tests/tools/test_tool_transform.py index 369475615c..48f7db81cf 100644 --- a/tests/tools/test_tool_transform.py +++ b/tests/tools/test_tool_transform.py @@ -1061,7 +1061,7 @@ def add(x: int, y: int = 10) -> int: mcp.add_tool(new_add) # Disable original tool, but new_add should still work - mcp.disable(names={"add"}, components=["tool"]) + mcp.disable(names={"add"}, components={"tool"}) async with Client(mcp) as client: tools = await client.list_tools() @@ -1088,8 +1088,8 @@ def add(x: int, y: int = 10) -> int: mcp.add_tool(new_add) # Disable both tools via server - mcp.disable(names={"add"}, components=["tool"]).disable( - names={"new_add"}, components=["tool"] + mcp.disable(names={"add"}, components={"tool"}).disable( + names={"new_add"}, components={"tool"} ) async with Client(mcp) as client: