Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu committed Jan 13, 2024
1 parent 2e519b0 commit e3c7fe6
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
17 changes: 16 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,7 @@ def register_for_llm(
*,
name: Optional[str] = None,
description: Optional[str] = None,
api_style: Optional[str] = None,
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be used by an agent.
Expand All @@ -1667,6 +1668,9 @@ def register_for_llm(
name (optional(str)): name of the function. If None, the function name will be used (default: None).
description (optional(str)): description of the function (default: None). It is mandatory
for the initial decorator, but the following ones can omit it.
api_style: (optional(str)): the API style for function call.
For Azure OpenAI API as of 2023-09-01-preview, you should set this to
`"function"`. By default, it uses the OpenAI API's tool calling style.
Returns:
The decorator for registering a function to be used by an agent.
Expand All @@ -1680,6 +1684,13 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c
return a + str(b * c)
```
For Azure OpenAI API as of 2023-09-01-preview, you should set `api_style` to `"function"`:
```
@agent2.register_for_llm(api_style="function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str:
return a + str(b * c)
```
"""

def _decorator(func: F) -> F:
Expand Down Expand Up @@ -1716,7 +1727,11 @@ def _decorator(func: F) -> F:
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

self.update_tool_signature(f, is_remove=False)
if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
else:
self.update_tool_signature(f, is_remove=False)

return func

Expand Down
71 changes: 71 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,77 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> s
assert agent3.llm_config["tools"] == expected3


def test_register_for_llm_api_style_function():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})

@agent3.register_for_llm(api_style="function")
@agent2.register_for_llm(name="python", api_style="function")
@agent1.register_for_llm(
description="run cell in ipython and return the execution result.", api_style="function"
)
def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
pass

expected1 = [
{
"description": "run cell in ipython and return the execution result.",
"name": "exec_python",
"parameters": {
"type": "object",
"properties": {
"cell": {
"type": "string",
"description": "Valid Python cell to execute.",
}
},
"required": ["cell"],
},
}
]
expected2 = copy.deepcopy(expected1)
expected2[0]["name"] = "python"
expected3 = expected2

assert agent1.llm_config["functions"] == expected1
assert agent2.llm_config["functions"] == expected2
assert agent3.llm_config["functions"] == expected3

@agent3.register_for_llm(api_style="function")
@agent2.register_for_llm(api_style="function")
@agent1.register_for_llm(
name="sh", description="run a shell script and return the execution result.", api_style="function"
)
async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> str:
pass

expected1 = expected1 + [
{
"name": "sh",
"description": "run a shell script and return the execution result.",
"parameters": {
"type": "object",
"properties": {
"script": {
"type": "string",
"description": "Valid shell script to execute.",
}
},
"required": ["script"],
},
}
]
expected2 = expected2 + [expected1[1]]
expected3 = expected3 + [expected1[1]]

assert agent1.llm_config["functions"] == expected1
assert agent2.llm_config["functions"] == expected2
assert agent3.llm_config["functions"] == expected3


def test_register_for_llm_without_description():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
Expand Down

0 comments on commit e3c7fe6

Please sign in to comment.