Skip to content

Commit f3e9cab

Browse files
committed
Make MCPAgent more configurable
1 parent 59201da commit f3e9cab

File tree

6 files changed

+44
-26
lines changed

6 files changed

+44
-26
lines changed

coagent/agents/mcp_agent.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -21,36 +21,53 @@ class Prompt:
2121

2222

2323
class MCPAgent(ChatAgent):
24-
"""An agent that can use tools provided by MCP (Model Context Protocol) servers."""
24+
"""An agent that can use prompts and tools provided by MCP (Model Context Protocol) servers."""
2525

2626
def __init__(
2727
self,
28-
system: Prompt | str = "",
2928
mcp_server_base_url: str = "",
30-
selected_tools: list[str] | None = None,
29+
mcp_server_headers: dict[str, Any] | None = None,
30+
system: Prompt | str = "",
31+
tools: list[str] | None = None,
3132
client: ModelClient = default_model_client,
3233
) -> None:
3334
super().__init__(system="", client=client)
3435

3536
self._mcp_server_base_url: str = mcp_server_base_url
37+
self._mcp_server_headers: dict[str, Any] | None = mcp_server_headers
38+
3639
self._mcp_client_transport: AsyncContextManager[tuple] | None = None
3740
self._mcp_client_session: ClientSession | None = None
3841

3942
self._mcp_swarm_agent: SwarmAgent | None = None
4043
self._mcp_system_prompt_config: Prompt | str = system
4144
# The selected tools to use. If None, all available tools will be used.
42-
self._mcp_selected_tools: list[str] | None = selected_tools
45+
self._mcp_selected_tools: list[str] | None = tools
4346

4447
@property
4548
def mcp_server_base_url(self) -> str:
4649
if not self._mcp_server_base_url:
4750
raise ValueError("MCP server base URL is empty")
4851
return self._mcp_server_base_url
4952

53+
@property
54+
def mcp_server_headers(self) -> dict[str, Any] | None:
55+
return self._mcp_server_headers
56+
57+
@property
58+
def system(self) -> Prompt | str:
59+
"""Note that this property is different from the `system` property in ChatAgent."""
60+
return self._mcp_system_prompt_config
61+
62+
@property
63+
def tools(self) -> list[str] | None:
64+
"""Note that this property is different from the `tools` property in ChatAgent."""
65+
return self._mcp_selected_tools
66+
5067
def _make_mcp_client_transport(self) -> AsyncContextManager[tuple]:
5168
if self.mcp_server_base_url.startswith(("http://", "https://")):
5269
url = urljoin(self.mcp_server_base_url, "sse")
53-
return sse_client(url=url)
70+
return sse_client(url=url, headers=self.mcp_server_headers)
5471
else:
5572
# Mainly for testing purposes.
5673
command, arg = self.mcp_server_base_url.split(" ", 1)
@@ -88,8 +105,8 @@ async def _handle_data(self) -> None:
88105

89106
async def get_swarm_agent(self) -> SwarmAgent:
90107
if not self._mcp_swarm_agent:
91-
system = await self._get_prompt(self._mcp_system_prompt_config)
92-
tools = await self._get_tools()
108+
system = await self._get_prompt(self.system)
109+
tools = await self._get_tools(self.tools)
93110
self._mcp_swarm_agent = SwarmAgent(
94111
name=self.name,
95112
model=self.client.model,
@@ -117,13 +134,13 @@ async def _get_prompt(self, prompt_config: Prompt | str) -> str:
117134
case _: # ImageContent() or EmbeddedResource() or other types
118135
return ""
119136

120-
async def _get_tools(self) -> list[Callable]:
137+
async def _get_tools(self, selected_tools: list[str] | None) -> list[Callable]:
121138
result = await self._mcp_client_session.list_tools()
122139

123140
def filter_tool(t: Tool) -> bool:
124-
if self._mcp_selected_tools is None:
141+
if selected_tools is None:
125142
return True
126-
return t.name in self._mcp_selected_tools
143+
return t.name in selected_tools
127144

128145
tools = [self._make_tool(t) for t in result.tools if filter_tool(t)]
129146

examples/mcp/agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
"mcp",
1111
new(
1212
MCPAgent,
13-
system=Prompt(name="system_prompt", arguments={"role": "Weather Reporter"}),
1413
mcp_server_base_url="http://localhost:8080",
14+
system=Prompt(name="system_prompt", arguments={"role": "Weather Reporter"}),
1515
),
1616
)
1717

poetry.lock

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ blinker = "1.9.0"
3535
loguru = "0.7.3"
3636
jq = "1.8.0"
3737
litellm = "1.60.4"
38-
mcp = "1.2.0"
38+
mcp = "1.3.0"
3939
jinja2 = "3.1.5"
4040

4141
[tool.pyright]

tests/agents/test_mcp_agent.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def test_get_tools(self):
3737
agent = MCPAgent(mcp_server_base_url="python tests/agents/mcp_server.py")
3838
await agent.started()
3939

40-
tools = await agent._get_tools()
40+
tools = await agent._get_tools(None)
4141
assert len(tools) == 2
4242

4343
# Get query_weather
@@ -99,13 +99,14 @@ async def test_get_tools(self):
9999
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
100100
@pytest.mark.asyncio
101101
async def test_get_tools_with_selection(self):
102+
selected_tools = ["query_weather"]
102103
agent = MCPAgent(
103104
mcp_server_base_url="python tests/agents/mcp_server.py",
104-
selected_tools=["query_weather"],
105+
tools=selected_tools,
105106
)
106107
await agent.started()
107108

108-
tools = await agent._get_tools()
109+
tools = await agent._get_tools(selected_tools)
109110
assert len(tools) == 1
110111

111112
tool = tools[0]

uv.lock

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)