Skip to content

Commit 4b5f599

Browse files
jtoyJasonekzhu
authored
add warning if duplicate function is registered (#2159)
* add warning if duplicate function is registereed * check _function_map and llm_config * check function_map and llm_config * use register_function and llm_config * cleanups * cleanups * warning test * warning test * more test coverage * use a fake config * formatting * formatting --------- Co-authored-by: Jason <[email protected]> Co-authored-by: Eric Zhu <[email protected]>
1 parent b698a98 commit 4b5f599

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

autogen/agentchat/conversable_agent.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -2406,6 +2406,8 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]):
24062406
self._assert_valid_name(name)
24072407
if func is None and name not in self._function_map.keys():
24082408
warnings.warn(f"The function {name} to remove doesn't exist", name)
2409+
if name in self._function_map:
2410+
warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
24092411
self._function_map.update(function_map)
24102412
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
24112413

@@ -2442,6 +2444,9 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
24422444

24432445
self._assert_valid_name(func_sig["name"])
24442446
if "functions" in self.llm_config.keys():
2447+
if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]):
2448+
warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning)
2449+
24452450
self.llm_config["functions"] = [
24462451
func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]
24472452
] + [func_sig]
@@ -2481,7 +2486,9 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
24812486
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
24822487
)
24832488
self._assert_valid_name(tool_sig["function"]["name"])
2484-
if "tools" in self.llm_config.keys():
2489+
if "tools" in self.llm_config:
2490+
if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]):
2491+
warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning)
24852492
self.llm_config["tools"] = [
24862493
tool
24872494
for tool in self.llm_config["tools"]

test/agentchat/test_conversable_agent.py

+59
Original file line numberDiff line numberDiff line change
@@ -1403,6 +1403,64 @@ def test_http_client():
14031403
)
14041404

14051405

1406+
def test_adding_duplicate_function_warning():
1407+
1408+
config_base = [{"base_url": "http://0.0.0.0:8000", "api_key": "NULL"}]
1409+
1410+
agent = autogen.ConversableAgent(
1411+
"jtoy",
1412+
llm_config={"config_list": config_base},
1413+
)
1414+
1415+
def sample_function():
1416+
pass
1417+
1418+
agent.register_function(
1419+
function_map={
1420+
"sample_function": sample_function,
1421+
}
1422+
)
1423+
agent.update_function_signature(
1424+
{
1425+
"name": "foo",
1426+
},
1427+
is_remove=False,
1428+
)
1429+
agent.update_tool_signature(
1430+
{
1431+
"type": "function",
1432+
"function": {
1433+
"name": "yo",
1434+
},
1435+
},
1436+
is_remove=False,
1437+
)
1438+
1439+
with pytest.warns(UserWarning, match="Function 'sample_function' is being overridden."):
1440+
agent.register_function(
1441+
function_map={
1442+
"sample_function": sample_function,
1443+
}
1444+
)
1445+
with pytest.warns(UserWarning, match="Function 'foo' is being overridden."):
1446+
agent.update_function_signature(
1447+
{
1448+
"name": "foo",
1449+
},
1450+
is_remove=False,
1451+
)
1452+
with pytest.warns(UserWarning, match="Function 'yo' is being overridden."):
1453+
agent.update_tool_signature(
1454+
{
1455+
"type": "function",
1456+
"function": {
1457+
"name": "yo",
1458+
},
1459+
},
1460+
is_remove=False,
1461+
)
1462+
1463+
14061464
if __name__ == "__main__":
14071465
# test_trigger()
14081466
# test_context()
@@ -1414,4 +1472,5 @@ def test_http_client():
14141472
# test_process_before_send()
14151473
# test_message_func()
14161474
test_summary()
1475+
test_adding_duplicate_function_warning()
14171476
# test_function_registration_e2e_sync()

0 commit comments

Comments
 (0)