From 662f2f66368493a8156006ffb885d2ef5764d310 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:39:37 -0700 Subject: [PATCH] Add multi-server support to MCPEnvironment --- examples/archive/mcp/README.md | 29 +- rllm/environments/tools/mcp_env.py | 429 ++++++++++++++++++++++++++--- tests/envs/test_mcp_env.py | 400 ++++++++++++++++++++++++++- 3 files changed, 808 insertions(+), 50 deletions(-) diff --git a/examples/archive/mcp/README.md b/examples/archive/mcp/README.md index d87c3226c..ece63b705 100644 --- a/examples/archive/mcp/README.md +++ b/examples/archive/mcp/README.md @@ -7,7 +7,7 @@ This example demonstrates how to use external MCP servers as tool providers with ```bash # Install MCP CLI (if needed for other MCP servers) uv pip install mcp - +``` ## Files @@ -71,6 +71,33 @@ This will: - **`MCPEnvironment`** - Environment that manages MCP server connections and tool execution - **`MCPConnectionManager`** - Handles MCP server lifecycle and tool discovery +### Multiple MCP Servers + +`MCPEnvironment` also supports routing tool calls across multiple named MCP servers: + +```python +env = MCPEnvironment( + task={"question": "Find and summarize the latest updates."}, + mcp_servers={ + "search": { + "command": "npx", + "args": ["-y", "tavily-mcp@0.2.4"], + "env": {"TAVILY_API_KEY": "..."}, + }, + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + }, + }, + tool_name_to_server_name={ + "tavily_search": "search", + "read_file": "filesystem", + }, +) +``` + +Use `tool_name_to_server_name` when multiple servers expose the same public tool name, including underscore aliases for tools whose original MCP names contain hyphens. + ### Integration with RLLM The example follows standard RLLM patterns: diff --git a/rllm/environments/tools/mcp_env.py b/rllm/environments/tools/mcp_env.py index 16c5c4692..bb6b18cd7 100644 --- a/rllm/environments/tools/mcp_env.py +++ b/rllm/environments/tools/mcp_env.py @@ -4,6 +4,7 @@ import threading import warnings from contextlib import AsyncExitStack +from dataclasses import dataclass from typing import Any try: @@ -17,6 +18,135 @@ from rllm.tools.mcp_tool import MCPTool +@dataclass(frozen=True) +class MCPServerSpec: + name: str + command: str + args: tuple[str, ...] = () + env_items: tuple[tuple[str, str], ...] | None = None + + @property + def args_list(self) -> list[str]: + return list(self.args) + + @property + def env_dict(self) -> dict[str, str] | None: + if self.env_items is None: + return None + return dict(self.env_items) + + +def _normalize_server_spec(name: str, config: dict[str, Any]) -> MCPServerSpec: + if not isinstance(config, dict): + raise ValueError(f"Config for MCP server '{name}' must be a dictionary") + + command = config.get("command", config.get("mcp_server_command")) + if not command: + raise ValueError(f"Config for MCP server '{name}' must include 'command'") + + raw_args = config.get("args", config.get("mcp_server_args")) or [] + if not isinstance(raw_args, list | tuple): + raise ValueError(f"Config for MCP server '{name}' must include list-like 'args'") + args = tuple(str(arg) for arg in raw_args) + + raw_env = config.get("env", config.get("mcp_server_env")) + if raw_env is not None and not isinstance(raw_env, dict): + raise ValueError(f"Config for MCP server '{name}' must include dict-like 'env'") + env_items = tuple(sorted((str(key), str(value)) for key, value in raw_env.items())) if raw_env is not None else None + + return MCPServerSpec(name=name, command=str(command), args=args, env_items=env_items) + + +def _normalize_mcp_servers( + mcp_server_command: str | None, + mcp_server_args: list[str] | None, + mcp_server_env: dict[str, str] | None, + mcp_servers: dict[str, dict[str, Any]] | None, +) -> dict[str, MCPServerSpec]: + has_legacy_config = mcp_server_command is not None or mcp_server_args is not None or mcp_server_env is not None + + if mcp_servers is not None and has_legacy_config: + raise ValueError("Cannot specify both legacy single-server MCP args and 'mcp_servers'") + + if mcp_servers is not None: + if not isinstance(mcp_servers, dict): + raise ValueError("'mcp_servers' must be a dictionary mapping server names to configs") + return {server_name: _normalize_server_spec(server_name, server_config) for server_name, server_config in mcp_servers.items()} + + if mcp_server_command is None: + return {} + + return { + "default": MCPServerSpec( + name="default", + command=str(mcp_server_command), + args=tuple(str(arg) for arg in (mcp_server_args or [])), + env_items=tuple(sorted((str(key), str(value)) for key, value in mcp_server_env.items())) if mcp_server_env is not None else None, + ) + } + + +def _tool_call_id(tool_call: Any, fallback_idx: int) -> str: + if isinstance(tool_call, dict): + tool_call_id = tool_call.get("id") + if isinstance(tool_call_id, str) and tool_call_id: + return tool_call_id + return f"tool_call_{fallback_idx}" + + +def _tool_call_name(tool_call: Any) -> str | None: + if not isinstance(tool_call, dict): + return None + function = tool_call.get("function") + if not isinstance(function, dict): + return None + tool_name = function.get("name") + if isinstance(tool_name, str) and tool_name: + return tool_name + return None + + +def _parse_tool_arguments(tool_call: Any) -> tuple[dict[str, Any] | None, str | None]: + if not isinstance(tool_call, dict): + return None, "Tool call must be a dictionary" + + function = tool_call.get("function") + if not isinstance(function, dict): + return None, "Tool call missing function payload" + + raw_arguments = function.get("arguments", {}) + if isinstance(raw_arguments, dict): + return raw_arguments, None + if isinstance(raw_arguments, str): + try: + parsed = json.loads(raw_arguments) + except json.JSONDecodeError as exc: + return None, f"Invalid tool arguments JSON: {exc}" + if not isinstance(parsed, dict): + return None, "Tool arguments JSON must decode to an object" + return parsed, None + return None, "Tool arguments must be a dict or JSON string" + + +def _assign_missing_tool_call_ids(tool_calls: list[Any]) -> list[Any]: + normalized_tool_calls: list[Any] = [] + for idx, tool_call in enumerate(tool_calls): + if not isinstance(tool_call, dict): + normalized_tool_calls.append(tool_call) + continue + + tool_call_id = tool_call.get("id") + if isinstance(tool_call_id, str) and tool_call_id: + normalized_tool_calls.append(tool_call) + continue + + normalized_tool_call = dict(tool_call) + normalized_tool_call["id"] = _tool_call_id(tool_call, idx) + normalized_tool_calls.append(normalized_tool_call) + + return normalized_tool_calls + + class MCPConnectionManager: """Manages MCP connections in a dedicated thread to avoid asyncio context issues.""" @@ -154,16 +284,24 @@ async def _execute_tools(self, tool_calls: list[dict[str, Any]]) -> dict[str, st """Execute tool calls.""" tool_outputs: dict[str, str] = {} - for tool_call in tool_calls: - tool_name = tool_call["function"]["name"] - tool_args = json.loads(tool_call["function"]["arguments"]) + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + tool_name = _tool_call_name(tool_call) + if tool_name is None: + tool_outputs[tool_call_id] = "Error: Tool call missing function.name" + continue + + tool_args, parse_error = _parse_tool_arguments(tool_call) + if parse_error is not None or tool_args is None: + tool_outputs[tool_call_id] = f"Error: {parse_error}" + continue if tool_name in self.tool_map: tool_instance = self.tool_map[tool_name] result = await tool_instance.async_forward(**tool_args) - tool_outputs[tool_call["id"]] = result.to_string() + tool_outputs[tool_call_id] = result.to_string() else: - tool_outputs[tool_call["id"]] = f"Error: Tool {tool_name} not found" + tool_outputs[tool_call_id] = f"Error: Tool {tool_name} not found" return tool_outputs @@ -179,11 +317,23 @@ class MCPEnvironment(BaseEnv): Uses a dedicated connection manager to avoid asyncio context issues. """ - # Class-level connection manager to share across instances - _connection_manager: MCPConnectionManager | None = None + # Class-level connection managers shared across instances + _connection_manager: MCPConnectionManager | None = None # backward-compatible alias for single-server usage + _connection_managers: dict[str, MCPConnectionManager] = {} + _server_specs: dict[str, MCPServerSpec] = {} _manager_lock = threading.Lock() - def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | None = None, mcp_server_args: list[str] | None = None, mcp_server_env: dict[str, str] | None = None, reward_fn: RewardFunction | None = None, max_steps: int = 10): + def __init__( + self, + task: dict[str, Any] | None = None, + mcp_server_command: str | None = None, + mcp_server_args: list[str] | None = None, + mcp_server_env: dict[str, str] | None = None, + mcp_servers: dict[str, dict[str, Any]] | None = None, + tool_name_to_server_name: dict[str, str] | None = None, + reward_fn: RewardFunction | None = None, + max_steps: int = 10, + ): """ Initialize the MCPEnvironment. @@ -192,6 +342,9 @@ def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | mcp_server_command: Command to run the MCP server. mcp_server_args: Arguments for the MCP server. mcp_server_env: Environment variables for the MCP server. + mcp_servers: Named MCP server configurations for multi-server routing. + tool_name_to_server_name: Optional explicit mapping from public tool names + to MCP server names. reward_fn: Reward function to use for evaluation. max_steps: Maximum number of steps allowed in the environment. """ @@ -206,12 +359,202 @@ def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | self.mcp_server_command = mcp_server_command self.mcp_server_args = mcp_server_args or [] self.mcp_server_env = mcp_server_env + self.mcp_servers = _normalize_mcp_servers(mcp_server_command, mcp_server_args, mcp_server_env, mcp_servers) + self.tool_name_to_server_name = dict(tool_name_to_server_name or {}) + self._resolved_tool_name_to_server_name: dict[str, str] = {} - # Initialize shared connection manager - with MCPEnvironment._manager_lock: - if MCPEnvironment._connection_manager is None and mcp_server_command is not None: - MCPEnvironment._connection_manager = MCPConnectionManager(mcp_server_command, mcp_server_args, mcp_server_env) - MCPEnvironment._connection_manager.start() + newly_created_server_names: list[str] = [] + try: + newly_created_server_names = self._ensure_connection_managers() + self._resolved_tool_name_to_server_name = self._build_tool_routing() + except Exception: + if newly_created_server_names: + self._rollback_connection_managers(newly_created_server_names) + raise + + @classmethod + def _sync_connection_manager_alias_locked(cls) -> None: + cls._connection_manager = next(iter(cls._connection_managers.values())) if len(cls._connection_managers) == 1 else None + + @classmethod + def _rollback_connection_managers(cls, server_names: list[str]) -> None: + managers_to_stop: list[MCPConnectionManager] = [] + with cls._manager_lock: + for server_name in server_names: + manager = cls._connection_managers.pop(server_name, None) + cls._server_specs.pop(server_name, None) + if manager is not None: + managers_to_stop.append(manager) + cls._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass + + def _ensure_connection_managers(self) -> list[str]: + newly_created_server_names: list[str] = [] + managers_to_stop: list[MCPConnectionManager] = [] + + try: + with MCPEnvironment._manager_lock: + for server_name, server_spec in self.mcp_servers.items(): + existing_spec = MCPEnvironment._server_specs.get(server_name) + if existing_spec is not None: + if existing_spec != server_spec: + raise ValueError( + f"MCP server '{server_name}' is already initialized with a different configuration" + ) + continue + + manager = MCPConnectionManager( + mcp_server_command=server_spec.command, + mcp_server_args=server_spec.args_list, + mcp_server_env=server_spec.env_dict, + ) + try: + manager.start() + except Exception: + managers_to_stop.append(manager) + raise + + MCPEnvironment._connection_managers[server_name] = manager + MCPEnvironment._server_specs[server_name] = server_spec + newly_created_server_names.append(server_name) + + MCPEnvironment._sync_connection_manager_alias_locked() + except Exception: + with MCPEnvironment._manager_lock: + for server_name in newly_created_server_names: + manager = MCPEnvironment._connection_managers.pop(server_name, None) + MCPEnvironment._server_specs.pop(server_name, None) + if manager is not None: + managers_to_stop.append(manager) + MCPEnvironment._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass + raise + + return newly_created_server_names + + def _build_tool_routing(self) -> dict[str, str]: + if not self.mcp_servers: + return {} + + discovered_tool_servers: dict[str, set[str]] = {} + for server_name in self.mcp_servers: + manager = MCPEnvironment._connection_managers.get(server_name) + if manager is None: + continue + for public_tool_name in getattr(manager, "tool_map", {}): + discovered_tool_servers.setdefault(public_tool_name, set()).add(server_name) + + resolved: dict[str, str] = {} + + for public_tool_name, candidate_servers in discovered_tool_servers.items(): + explicit_server_name = self.tool_name_to_server_name.get(public_tool_name) + if explicit_server_name is not None: + if explicit_server_name not in candidate_servers: + raise ValueError( + f"Tool '{public_tool_name}' is not provided by mapped MCP server '{explicit_server_name}'" + ) + resolved[public_tool_name] = explicit_server_name + elif len(candidate_servers) == 1: + resolved[public_tool_name] = next(iter(candidate_servers)) + else: + raise ValueError( + f"Tool '{public_tool_name}' is provided by multiple MCP servers {sorted(candidate_servers)}. " + "Supply 'tool_name_to_server_name' to disambiguate." + ) + + for public_tool_name, mapped_server_name in self.tool_name_to_server_name.items(): + if mapped_server_name not in self.mcp_servers: + raise ValueError( + f"Tool mapping for '{public_tool_name}' references unknown MCP server '{mapped_server_name}'" + ) + if public_tool_name not in discovered_tool_servers: + raise ValueError( + f"Tool mapping for '{public_tool_name}' does not match any discovered tool on the configured MCP servers" + ) + + return resolved + + @staticmethod + def _is_finish_tool_call(tool_call: Any) -> bool: + return _tool_call_name(tool_call) == "finish" + + def _extract_final_response(self, action: list[dict[str, Any]] | str) -> str: + if isinstance(action, str): + return action + + finish_action = None + for tool_call in action: + if self._is_finish_tool_call(tool_call): + finish_action = tool_call + break + + if finish_action is None: + return str(action) + + arguments, parse_error = _parse_tool_arguments(finish_action) + if parse_error is not None or arguments is None: + return str(action) + + response = arguments.get("response", "") + return response if isinstance(response, str) else str(response) + + def _execute_tool_calls_by_server(self, tool_calls: list[dict[str, Any]]) -> dict[str, str]: + tool_calls = _assign_missing_tool_call_ids(tool_calls) + tool_outputs: dict[str, str] = {} + grouped_calls: dict[str, list[dict[str, Any]]] = {} + + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + tool_name = _tool_call_name(tool_call) + if tool_name is None: + tool_outputs[tool_call_id] = "Error: Tool call missing function.name" + continue + + server_name = self._resolved_tool_name_to_server_name.get(tool_name) + if server_name is None and len(self.mcp_servers) == 1: + # Preserve legacy single-server behavior where every tool call is + # forwarded to the sole configured MCP server. + server_name = next(iter(self.mcp_servers)) + if server_name is None: + tool_outputs[tool_call_id] = f"Error: Tool {tool_name} not found" + continue + + grouped_calls.setdefault(server_name, []).append(tool_call) + + for server_name, grouped_tool_calls in grouped_calls.items(): + manager = MCPEnvironment._connection_managers.get(server_name) + if manager is None: + for idx, tool_call in enumerate(grouped_tool_calls): + tool_outputs[_tool_call_id(tool_call, idx)] = f"Error: MCP server {server_name} is not available" + continue + + try: + tool_outputs.update(manager.execute_tool_calls(grouped_tool_calls)) + except Exception as exc: + for idx, tool_call in enumerate(grouped_tool_calls): + tool_outputs[_tool_call_id(tool_call, idx)] = f"Error: MCP server {server_name} failed: {exc}" + + ordered_tool_outputs: dict[str, str] = {} + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + if tool_call_id in tool_outputs: + ordered_tool_outputs[tool_call_id] = tool_outputs[tool_call_id] + + for tool_call_id, tool_output in tool_outputs.items(): + if tool_call_id not in ordered_tool_outputs: + ordered_tool_outputs[tool_call_id] = tool_output + + return ordered_tool_outputs def reset(self): """Reset the environment and return initial observations.""" @@ -239,32 +582,13 @@ def step(self, action: Any): # Check if action contains a "finish" tool call if isinstance(action, list) and action: for tool_call in action: - if tool_call.get("function", {}).get("name") == "finish": + if self._is_finish_tool_call(tool_call): done = True break if done: # Agent is done - evaluate the response - if isinstance(action, str): - llm_response = action - elif isinstance(action, list): - # Find the finish tool call - finish_action = None - for tool_call in action: - if tool_call.get("function", {}).get("name") == "finish": - finish_action = tool_call - break - if finish_action: - arguments = finish_action.get("function", {}).get("arguments", {}) - if isinstance(arguments, str): - arguments = json.loads(arguments) - - if isinstance(arguments, dict): - llm_response = arguments.get("response", "") - else: - llm_response = str(arguments) - else: - llm_response = str(action) + llm_response = self._extract_final_response(action) if self.reward_fn and self.task is not None: reward_output = self.reward_fn(task_info=self.task, action=llm_response) @@ -275,11 +599,8 @@ def step(self, action: Any): # Execute tool calls using the connection manager tool_calls = action try: - if MCPEnvironment._connection_manager is not None: - tool_outputs = MCPEnvironment._connection_manager.execute_tool_calls(tool_calls) - next_obs = {"tool_outputs": tool_outputs} - else: - next_obs = {"tool_outputs": {}} + tool_outputs = self._execute_tool_calls_by_server(tool_calls) if isinstance(tool_calls, list) else {} + next_obs = {"tool_outputs": tool_outputs} except Exception as e: print(f"Tool execution error: {e}") next_obs = {"tool_outputs": {}} @@ -293,17 +614,37 @@ def close(self): @staticmethod def cleanup_global_resources(): - """Clean up global connection manager.""" + """Clean up global connection managers.""" + managers_to_stop: list[MCPConnectionManager] = [] with MCPEnvironment._manager_lock: - if MCPEnvironment._connection_manager: - MCPEnvironment._connection_manager.stop() - MCPEnvironment._connection_manager = None + managers_to_stop = list(MCPEnvironment._connection_managers.values()) + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + MCPEnvironment._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass @staticmethod def from_dict(env_args: dict[str, Any]) -> "MCPEnvironment": + env_args = dict(env_args) mcp_server_command = env_args.pop("mcp_server_command", None) mcp_server_args = env_args.pop("mcp_server_args", None) mcp_server_env = env_args.pop("mcp_server_env", None) + mcp_servers = env_args.pop("mcp_servers", None) + tool_name_to_server_name = env_args.pop("tool_name_to_server_name", None) reward_fn = env_args.pop("reward_fn", None) max_steps = env_args.pop("max_steps", 10) - return MCPEnvironment(task=env_args, mcp_server_command=mcp_server_command, mcp_server_args=mcp_server_args, mcp_server_env=mcp_server_env, max_steps=max_steps, reward_fn=reward_fn) + return MCPEnvironment( + task=env_args, + mcp_server_command=mcp_server_command, + mcp_server_args=mcp_server_args, + mcp_server_env=mcp_server_env, + mcp_servers=mcp_servers, + tool_name_to_server_name=tool_name_to_server_name, + max_steps=max_steps, + reward_fn=reward_fn, + ) diff --git a/tests/envs/test_mcp_env.py b/tests/envs/test_mcp_env.py index 0ce9953cf..f1189a3e6 100644 --- a/tests/envs/test_mcp_env.py +++ b/tests/envs/test_mcp_env.py @@ -16,6 +16,25 @@ def __call__(self, task_info, action, **kwargs): return RewardOutput(reward=reward, metadata=metadata) +@pytest.fixture(autouse=True) +def reset_mcp_environment_state(): + MCPEnvironment._connection_manager = None + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + yield + MCPEnvironment._connection_manager = None + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + + +def make_start_side_effect(command_to_tools): + def _start(manager): + manager.running = True + manager.tool_map = command_to_tools.get(manager.mcp_server_command, {}) + + return _start + + class TestMCPConnectionManager: """Test suite for MCPConnectionManager class.""" @@ -277,6 +296,7 @@ def test_step_with_regular_tool_calls(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}}] @@ -302,6 +322,7 @@ def test_step_max_steps_termination(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Take steps until max_steps for i in range(2): @@ -328,6 +349,7 @@ def test_step_with_dict_action(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = {"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}} @@ -370,11 +392,14 @@ def test_cleanup_global_resources_with_manager(self): """Test cleanup_global_resources with existing manager.""" mock_manager = Mock() MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} MCPEnvironment.cleanup_global_resources() mock_manager.stop.assert_called_once() assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -398,7 +423,14 @@ def test_is_multithread_safe(self): def test_from_dict(self): """Test creating environment from dictionary.""" - env_args = {"question": "Test question", "mcp_server_command": "test_command", "mcp_server_args": ["--arg1"], "mcp_server_env": {"VAR": "value"}, "max_steps": 15, "reward_fn": MockRewardFunction()} + env_args = { + "question": "Test question", + "mcp_server_command": "test_command", + "mcp_server_args": ["--arg1"], + "mcp_server_env": {"VAR": "value"}, + "max_steps": 15, + "reward_fn": MockRewardFunction(), + } with patch.object(MCPConnectionManager, "start"), patch.object(MCPConnectionManager, "__init__", return_value=None): # Clear any existing manager @@ -447,6 +479,7 @@ def test_full_interaction_flow(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Paris is the capital of France"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action1 = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "capital of France"}}}] @@ -499,12 +532,15 @@ def test_step_with_tool_execution_error(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.side_effect = Exception("Tool execution failed") MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}}] - # Should not raise error but return empty tool outputs obs, reward, done, info = env.step(action) - assert obs == {"tool_outputs": {}} + assert obs == {"tool_outputs": {"call_1": "Error: MCP server default failed: Tool execution failed"}} + assert reward == 0 + assert done is False + assert info["response"] == action @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -520,6 +556,7 @@ def test_edge_cases(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Empty action list obs, reward, done, info = env.step([]) @@ -569,6 +606,8 @@ def test_connection_manager_thread_safety(self): """Test that connection manager handles thread safety correctly.""" # Both environments should be able to access the class-level manager assert hasattr(MCPEnvironment, "_connection_manager") + assert hasattr(MCPEnvironment, "_connection_managers") + assert hasattr(MCPEnvironment, "_server_specs") assert hasattr(MCPEnvironment, "_manager_lock") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -583,6 +622,9 @@ def test_connection_manager_singleton_behavior(self, mock_start, mock_init): # Both environments should use the same manager assert MCPEnvironment._connection_manager is not None + assert len(MCPEnvironment._connection_managers) == 1 + assert mock_init.call_count == 1 + assert mock_start.call_count == 1 @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -598,12 +640,360 @@ def test_malformed_tool_call_handling(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Malformed action (missing required fields) action = [{"id": "call_1"}] # Missing function field obs, reward, done, info = env.step(action) - # Should still process the action - assert obs == {"tool_outputs": {"call_1": "Tool output"}} + assert obs == {"tool_outputs": {"call_1": "Error: Tool call missing function.name"}} + assert done is False + mock_manager.execute_tool_calls.assert_not_called() + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_init_with_multiple_servers(self, mock_start): + """Test initializing MCPEnvironment with multiple named servers.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + + assert set(env.mcp_servers) == {"search_server", "wiki_server"} + assert set(MCPEnvironment._connection_managers) == {"search_server", "wiki_server"} + assert MCPEnvironment._connection_manager is None + assert env._resolved_tool_name_to_server_name == { + "search": "search_server", + "lookup": "wiki_server", + } + assert mock_start.call_count == 2 + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_same_server_name_with_different_config_raises(self, mock_start): + """Test that reusing a server name with a different config fails fast.""" + mock_start.side_effect = make_start_side_effect({"search-command": {"search": Mock()}}) + + MCPEnvironment(mcp_servers={"shared": {"command": "search-command"}}) + + with pytest.raises(ValueError, match="already initialized with a different configuration"): + MCPEnvironment(mcp_servers={"shared": {"command": "different-command"}}) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_explicit_tool_name_to_server_name_resolves_ambiguity(self, mock_start): + """Test that explicit tool routing resolves duplicate tool names.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"shared_tool": Mock()}, + "command-b": {"shared_tool": Mock()}, + } + ) + + env = MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + }, + tool_name_to_server_name={"shared_tool": "server_b"}, + ) + + assert env._resolved_tool_name_to_server_name == {"shared_tool": "server_b"} + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_duplicate_public_tool_name_without_mapping_raises(self, mock_start): + """Test that duplicate tool names across servers require explicit routing.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"shared_tool": Mock()}, + "command-b": {"shared_tool": Mock()}, + } + ) + + with pytest.raises(ValueError, match="Tool 'shared_tool' is provided by multiple MCP servers"): + MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + } + ) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_routes_tool_calls_to_correct_server(self, mock_start): + """Test that tool calls are routed to the correct MCP server.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(return_value={"call_1": "Search output"}) + wiki_manager.execute_tool_calls = Mock(return_value={"call_2": "Lookup output"}) + + action = [ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == {"tool_outputs": {"call_1": "Search output", "call_2": "Lookup output"}} + assert reward == 0 + assert done is False + assert info["response"] == action + search_manager.execute_tool_calls.assert_called_once_with([action[0]]) + wiki_manager.execute_tool_calls.assert_called_once_with([action[1]]) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_partial_server_failure_does_not_erase_other_outputs(self, mock_start): + """Test that one server failure does not discard successful tool outputs.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(return_value={"call_1": "Search output"}) + wiki_manager.execute_tool_calls = Mock(side_effect=Exception("wiki unavailable")) + + action = [ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == { + "tool_outputs": { + "call_1": "Search output", + "call_2": "Error: MCP server wiki_server failed: wiki unavailable", + } + } + assert reward == 0 assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_assigns_missing_tool_call_ids_across_servers(self, mock_start): + """Test that synthetic tool call ids stay unique across routed server groups.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(side_effect=lambda tool_calls: {tool_calls[0]["id"]: "Search output"}) + wiki_manager.execute_tool_calls = Mock(side_effect=lambda tool_calls: {tool_calls[0]["id"]: "Lookup output"}) + + action = [ + {"function": {"name": "search", "arguments": {"query": "France"}}}, + {"function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == { + "tool_outputs": { + "tool_call_0": "Search output", + "tool_call_1": "Lookup output", + } + } + assert reward == 0 + assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_preserves_interleaved_tool_output_order_across_servers(self, mock_start): + """Test that output ordering follows the original tool-call order.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock( + return_value={ + "call_1": "Search output 1", + "call_3": "Search output 2", + } + ) + wiki_manager.execute_tool_calls = Mock(return_value={"call_2": "Lookup output"}) + + action = [ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + {"id": "call_3", "function": {"name": "search", "arguments": {"query": "Europe"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert list(obs["tool_outputs"]) == ["call_1", "call_2", "call_3"] + assert reward == 0 + assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "stop", autospec=True) + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_start_failure_rolls_back_previously_started_managers(self, mock_start, mock_stop): + """Test that manager startup failures do not leave partial global state behind.""" + + def _start(manager): + if manager.mcp_server_command == "search-command": + manager.running = True + manager.tool_map = {"search": Mock()} + return + raise RuntimeError("startup failed") + + mock_start.side_effect = _start + + with pytest.raises(RuntimeError, match="startup failed"): + MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + + assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} + assert mock_stop.call_count == 2 + + @patch.object(MCPConnectionManager, "stop", autospec=True) + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_invalid_tool_mapping_rolls_back_new_managers(self, mock_start, mock_stop): + """Test that routing validation failures clean up newly started managers.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + + with pytest.raises(ValueError, match="does not match any discovered tool"): + MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + }, + tool_name_to_server_name={"missing_tool": "search_server"}, + ) + + assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} + assert mock_stop.call_count == 2 + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_from_dict_with_mcp_servers(self, mock_start): + """Test creating an environment from dictionary with multi-server config.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env_args = { + "question": "Test question", + "mcp_servers": { + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + }, + "tool_name_to_server_name": {"lookup": "wiki_server"}, + "max_steps": 15, + "reward_fn": MockRewardFunction(), + } + + env = MCPEnvironment.from_dict(env_args) + + assert isinstance(env, MCPEnvironment) + assert env.task == {"question": "Test question"} + assert env.max_steps == 15 + assert set(env.mcp_servers) == {"search_server", "wiki_server"} + assert env.tool_name_to_server_name == {"lookup": "wiki_server"} + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_from_dict_does_not_mutate_input(self, mock_start): + """Test that from_dict does not mutate the provided env_args dictionary.""" + mock_start.side_effect = make_start_side_effect({"search-command": {"search": Mock()}}) + env_args = { + "question": "Test question", + "mcp_servers": {"search_server": {"command": "search-command"}}, + "tool_name_to_server_name": {"search": "search_server"}, + "max_steps": 7, + } + expected_env_args = { + "question": "Test question", + "mcp_servers": {"search_server": {"command": "search-command"}}, + "tool_name_to_server_name": {"search": "search_server"}, + "max_steps": 7, + } + + MCPEnvironment.from_dict(env_args) + + assert env_args == expected_env_args + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_hyphenated_tool_aliases_are_checked_for_duplicates(self, mock_start): + """Test that hyphenated tool aliases participate in duplicate detection.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"search-tool": Mock(), "search_tool": Mock()}, + "command-b": {"search-tool": Mock(), "search_tool": Mock()}, + } + ) + + with pytest.raises(ValueError, match="Tool 'search-tool' is provided by multiple MCP servers"): + MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + } + )