diff --git a/src/fastmcp/server/transforms/visibility.py b/src/fastmcp/server/transforms/visibility.py index 41061a6510..5a35888866 100644 --- a/src/fastmcp/server/transforms/visibility.py +++ b/src/fastmcp/server/transforms/visibility.py @@ -171,23 +171,23 @@ def _matches(self, component: FastMCPComponent) -> bool: return self.tags is None or bool(component.tags & self.tags) def _mark_component(self, component: T) -> T: - """Set visibility state in component metadata if rule matches.""" + """Set visibility state in component metadata if rule matches. + + Returns a copy of the component with updated metadata to avoid + mutating shared objects cached in providers. + """ if not self._matches(component): return component - # Create new dicts to avoid mutating shared dicts - # (e.g., when Tool.from_tool shares the meta dict between tools) if component.meta is None: - component.meta = { - _FASTMCP_KEY: {_INTERNAL_KEY: {"visibility": self._enabled}} - } + new_meta = {_FASTMCP_KEY: {_INTERNAL_KEY: {"visibility": self._enabled}}} else: old_fastmcp = component.meta.get(_FASTMCP_KEY, {}) old_internal = old_fastmcp.get(_INTERNAL_KEY, {}) new_internal = {**old_internal, "visibility": self._enabled} new_fastmcp = {**old_fastmcp, _INTERNAL_KEY: new_internal} - component.meta = {**component.meta, _FASTMCP_KEY: new_fastmcp} - return component + new_meta = {**component.meta, _FASTMCP_KEY: new_fastmcp} + return component.model_copy(update={"meta": new_meta}) # ------------------------------------------------------------------------- # Transform methods (mark components, don't filter) diff --git a/tests/server/test_session_visibility.py b/tests/server/test_session_visibility.py index ae307dfc68..887f4f3307 100644 --- a/tests/server/test_session_visibility.py +++ b/tests/server/test_session_visibility.py @@ -618,3 +618,151 @@ async def non_activated_session(session_id: str): assert results[f"non_activated_{i}"] is False, ( f"Non-activated session {i} should NOT see premium tool" ) + + +class TestSessionVisibilityResetBug: + """Regression tests for #3034: visibility marks leak via shared component mutation.""" + + async def test_disable_then_reset_restores_tools(self): + """After disable + reset within the same session, tools should reappear.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"system"}) + def my_tool() -> str: + return "hello" + + @mcp.tool(tags={"env"}) + async def enter_env(ctx: Context) -> str: + await ctx.disable_components(tags={"system"}) + return "entered" + + @mcp.tool(tags={"env"}) + async def exit_env(ctx: Context) -> str: + await ctx.reset_visibility() + return "exited" + + async with Client(mcp) as client: + # Tool visible initially + tools = await client.list_tools() + assert any(t.name == "my_tool" for t in tools) + + # Disable it + await client.call_tool("enter_env", {}) + tools = await client.list_tools() + assert not any(t.name == "my_tool" for t in tools) + + # Reset — tool should come back + await client.call_tool("exit_env", {}) + tools = await client.list_tools() + assert any(t.name == "my_tool" for t in tools), ( + "Tool should be visible again after reset_visibility" + ) + + async def test_disable_reset_loop(self): + """Repeated disable/reset cycles should work every time (the exact bug from #3034).""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"system"}) + def create_project() -> str: + return "created" + + @mcp.tool(tags={"env"}) + async def enter_env(ctx: Context) -> str: + await ctx.disable_components(tags={"system"}) + return "entered" + + @mcp.tool(tags={"env"}) + async def exit_env(ctx: Context) -> str: + await ctx.reset_visibility() + return "exited" + + async with Client(mcp) as client: + for i in range(3): + # create_project should be visible + tools = await client.list_tools() + assert any(t.name == "create_project" for t in tools), ( + f"Iteration {i}: create_project should be visible before enter_env" + ) + + # Enter env — disables system tools + await client.call_tool("enter_env", {}) + tools = await client.list_tools() + assert not any(t.name == "create_project" for t in tools), ( + f"Iteration {i}: create_project should be hidden after enter_env" + ) + + # Exit env — reset + await client.call_tool("exit_env", {}) + + async def test_session_disable_does_not_leak_to_concurrent_session(self): + """Disabling tools in one session must not affect a concurrent session.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"system"}) + def shared_tool() -> str: + return "shared" + + @mcp.tool + async def disable_system(ctx: Context) -> str: + await ctx.disable_components(tags={"system"}) + return "disabled" + + session_b_sees_tool = False + ready = anyio.Event() + check_done = anyio.Event() + + async def session_a(): + async with Client(mcp) as client: + await client.call_tool("disable_system", {}) + ready.set() + await check_done.wait() + + async def session_b(): + nonlocal session_b_sees_tool + await ready.wait() + async with Client(mcp) as client: + tools = await client.list_tools() + session_b_sees_tool = any(t.name == "shared_tool" for t in tools) + check_done.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(session_a) + tg.start_soon(session_b) + + assert session_b_sees_tool is True, ( + "Session B should still see shared_tool despite Session A disabling it" + ) + + async def test_session_disable_does_not_leak_to_sequential_session(self): + """Disabling tools in one session must not affect a later session.""" + from fastmcp import Client + + mcp = FastMCP("test") + + @mcp.tool(tags={"system"}) + def shared_tool() -> str: + return "shared" + + @mcp.tool + async def disable_system(ctx: Context) -> str: + await ctx.disable_components(tags={"system"}) + return "disabled" + + # Session A disables the tool (no reset) + async with Client(mcp) as client_a: + await client_a.call_tool("disable_system", {}) + tools = await client_a.list_tools() + assert not any(t.name == "shared_tool" for t in tools) + + # Session B should see it fresh + async with Client(mcp) as client_b: + tools = await client_b.list_tools() + assert any(t.name == "shared_tool" for t in tools), ( + "New session should see shared_tool regardless of previous session" + ) diff --git a/tests/server/transforms/test_visibility.py b/tests/server/transforms/test_visibility.py index cf7b9f1b7e..a784af1f47 100644 --- a/tests/server/transforms/test_visibility.py +++ b/tests/server/transforms/test_visibility.py @@ -101,36 +101,39 @@ class TestMarking: def test_disable_marks_as_disabled(self): """Visibility(False, ...) marks matching components as disabled.""" tool = Tool(name="foo", parameters={}) - Visibility(False, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is False + marked = Visibility(False, names={"foo"})._mark_component(tool) + assert is_enabled(marked) is False def test_enable_marks_as_enabled(self): """Visibility(True, ...) marks matching components as enabled.""" tool = Tool(name="foo", parameters={}) - Visibility(True, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is True - assert tool.meta is not None - assert tool.meta["fastmcp"]["_internal"]["visibility"] is True + marked = Visibility(True, names={"foo"})._mark_component(tool) + assert is_enabled(marked) is True + assert marked.meta is not None + assert marked.meta["fastmcp"]["_internal"]["visibility"] is True def test_non_matching_unchanged(self): """Non-matching components are not modified.""" tool = Tool(name="bar", parameters={}) - Visibility(False, names={"foo"})._mark_component(tool) + result = Visibility(False, names={"foo"})._mark_component(tool) # No _internal key added - assert tool.meta is None or "_internal" not in tool.meta.get("fastmcp", {}) - assert is_enabled(tool) is True + assert result.meta is None or "_internal" not in result.meta.get("fastmcp", {}) + assert is_enabled(result) is True - def test_mutates_in_place(self): - """Marking mutates the component in place.""" + def test_returns_copy_for_matching(self): + """Marking returns a copy to avoid mutating shared provider objects.""" tool = Tool(name="foo", parameters={}) result = Visibility(False, names={"foo"})._mark_component(tool) - assert result is tool + assert result is not tool + assert is_enabled(result) is False + # Original is untouched + assert is_enabled(tool) is True def test_disable_all(self): """match_all=True disables all components.""" tool = Tool(name="anything", parameters={}) - Visibility(False, match_all=True)._mark_component(tool) - assert is_enabled(tool) is False + marked = Visibility(False, match_all=True)._mark_component(tool) + assert is_enabled(marked) is False class TestOverride: @@ -139,20 +142,20 @@ class TestOverride: def test_enable_overrides_disable(self): """An enable after disable results in enabled.""" tool = Tool(name="foo", parameters={}) - Visibility(False, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is False + marked = Visibility(False, names={"foo"})._mark_component(tool) + assert is_enabled(marked) is False - Visibility(True, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is True + marked = Visibility(True, names={"foo"})._mark_component(marked) + assert is_enabled(marked) is True def test_disable_overrides_enable(self): """A disable after enable results in disabled.""" tool = Tool(name="foo", parameters={}) - Visibility(True, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is True + marked = Visibility(True, names={"foo"})._mark_component(tool) + assert is_enabled(marked) is True - Visibility(False, names={"foo"})._mark_component(tool) - assert is_enabled(tool) is False + marked = Visibility(False, names={"foo"})._mark_component(marked) + assert is_enabled(marked) is False class TestHelperFunctions: @@ -169,9 +172,10 @@ def test_filtering_pattern(self): Tool(name="enabled", parameters={}), Tool(name="disabled", parameters={}), ] - Visibility(False, names={"disabled"})._mark_component(tools[1]) + vis = Visibility(False, names={"disabled"}) + marked_tools = [vis._mark_component(t) for t in tools] - visible = [t for t in tools if is_enabled(t)] + visible = [t for t in marked_tools if is_enabled(t)] assert [t.name for t in visible] == ["enabled"] @@ -181,14 +185,14 @@ class TestMetadata: def test_internal_metadata_stripped_by_get_meta(self): """Internal metadata is stripped when calling get_meta().""" tool = Tool(name="foo", parameters={}) - Visibility(True, names={"foo"})._mark_component(tool) + marked = Visibility(True, names={"foo"})._mark_component(tool) # Raw meta has _internal - assert tool.meta is not None - assert "_internal" in tool.meta.get("fastmcp", {}) + assert marked.meta is not None + assert "_internal" in marked.meta.get("fastmcp", {}) # get_meta() strips it - output = tool.get_meta() + output = marked.get_meta() assert "_internal" not in output.get("fastmcp", {}) def test_user_metadata_preserved(self):