Skip to content

Commit b872789

Browse files
skzhang1“skzhang1”
and
“skzhang1”
authored
Support functions removing in ConversableAgent (#1786)
* fix * update * reformat --------- Co-authored-by: “skzhang1” <“[email protected]”>
1 parent c6f6707 commit b872789

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

autogen/agentchat/conversable_agent.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -2129,15 +2129,18 @@ async def a_generate_init_message(self, **context) -> Union[str, Dict]:
21292129
self._process_carryover(context)
21302130
return context["message"]
21312131

2132-
def register_function(self, function_map: Dict[str, Callable]):
2132+
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
21332133
"""Register functions to the agent.
21342134
21352135
Args:
2136-
function_map: a dictionary mapping function names to functions.
2136+
function_map: a dictionary mapping function names to functions. if function_map[name] is None, the function will be removed from the function_map.
21372137
"""
2138-
for name in function_map.keys():
2138+
for name, func in function_map.items():
21392139
self._assert_valid_name(name)
2140+
if func is None and name not in self._function_map.keys():
2141+
warnings.warn(f"The function {name} to remove doesn't exist", name)
21402142
self._function_map.update(function_map)
2143+
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
21412144

21422145
def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None):
21432146
"""update a function_signature in the LLM configuration for function_call.

test/agentchat/test_conversable_agent.py

+10
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,16 @@ def exec_sh(script: str) -> None:
559559
assert agent.function_map["python"] == exec_python
560560
assert agent.function_map["sh"] == exec_sh
561561

562+
# remove the functions
563+
agent.register_function(
564+
function_map={
565+
"python": None,
566+
}
567+
)
568+
569+
assert set(agent.function_map.keys()) == {"sh"}
570+
assert agent.function_map["sh"] == exec_sh
571+
562572

563573
def test__wrap_function_sync():
564574
CurrencySymbol = Literal["USD", "EUR"]

0 commit comments

Comments
 (0)