@@ -21,36 +21,53 @@ class Prompt:
21
21
22
22
23
23
class MCPAgent (ChatAgent ):
24
- """An agent that can use tools provided by MCP (Model Context Protocol) servers."""
24
+ """An agent that can use prompts and tools provided by MCP (Model Context Protocol) servers."""
25
25
26
26
def __init__ (
27
27
self ,
28
- system : Prompt | str = "" ,
29
28
mcp_server_base_url : str = "" ,
30
- selected_tools : list [str ] | None = None ,
29
+ mcp_server_headers : dict [str , Any ] | None = None ,
30
+ system : Prompt | str = "" ,
31
+ tools : list [str ] | None = None ,
31
32
client : ModelClient = default_model_client ,
32
33
) -> None :
33
34
super ().__init__ (system = "" , client = client )
34
35
35
36
self ._mcp_server_base_url : str = mcp_server_base_url
37
+ self ._mcp_server_headers : dict [str , Any ] | None = mcp_server_headers
38
+
36
39
self ._mcp_client_transport : AsyncContextManager [tuple ] | None = None
37
40
self ._mcp_client_session : ClientSession | None = None
38
41
39
42
self ._mcp_swarm_agent : SwarmAgent | None = None
40
43
self ._mcp_system_prompt_config : Prompt | str = system
41
44
# The selected tools to use. If None, all available tools will be used.
42
- self ._mcp_selected_tools : list [str ] | None = selected_tools
45
+ self ._mcp_selected_tools : list [str ] | None = tools
43
46
44
47
@property
45
48
def mcp_server_base_url (self ) -> str :
46
49
if not self ._mcp_server_base_url :
47
50
raise ValueError ("MCP server base URL is empty" )
48
51
return self ._mcp_server_base_url
49
52
53
+ @property
54
+ def mcp_server_headers (self ) -> dict [str , Any ] | None :
55
+ return self ._mcp_server_headers
56
+
57
+ @property
58
+ def system (self ) -> Prompt | str :
59
+ """Note that this property is different from the `system` property in ChatAgent."""
60
+ return self ._mcp_system_prompt_config
61
+
62
+ @property
63
+ def tools (self ) -> list [str ] | None :
64
+ """Note that this property is different from the `tools` property in ChatAgent."""
65
+ return self ._mcp_selected_tools
66
+
50
67
def _make_mcp_client_transport (self ) -> AsyncContextManager [tuple ]:
51
68
if self .mcp_server_base_url .startswith (("http://" , "https://" )):
52
69
url = urljoin (self .mcp_server_base_url , "sse" )
53
- return sse_client (url = url )
70
+ return sse_client (url = url , headers = self . mcp_server_headers )
54
71
else :
55
72
# Mainly for testing purposes.
56
73
command , arg = self .mcp_server_base_url .split (" " , 1 )
@@ -88,8 +105,8 @@ async def _handle_data(self) -> None:
88
105
89
106
async def get_swarm_agent (self ) -> SwarmAgent :
90
107
if not self ._mcp_swarm_agent :
91
- system = await self ._get_prompt (self ._mcp_system_prompt_config )
92
- tools = await self ._get_tools ()
108
+ system = await self ._get_prompt (self .system )
109
+ tools = await self ._get_tools (self . tools )
93
110
self ._mcp_swarm_agent = SwarmAgent (
94
111
name = self .name ,
95
112
model = self .client .model ,
@@ -117,13 +134,13 @@ async def _get_prompt(self, prompt_config: Prompt | str) -> str:
117
134
case _: # ImageContent() or EmbeddedResource() or other types
118
135
return ""
119
136
120
- async def _get_tools (self ) -> list [Callable ]:
137
+ async def _get_tools (self , selected_tools : list [ str ] | None ) -> list [Callable ]:
121
138
result = await self ._mcp_client_session .list_tools ()
122
139
123
140
def filter_tool (t : Tool ) -> bool :
124
- if self . _mcp_selected_tools is None :
141
+ if selected_tools is None :
125
142
return True
126
- return t .name in self . _mcp_selected_tools
143
+ return t .name in selected_tools
127
144
128
145
tools = [self ._make_tool (t ) for t in result .tools if filter_tool (t )]
129
146
0 commit comments