Skip to content

Commit

Permalink
Fix registration of async functions (#1201)
Browse files Browse the repository at this point in the history
* bug fix for async functions

* Update test_conversable_agent.py

Co-authored-by: Chi Wang <[email protected]>

* Update test/agentchat/test_conversable_agent.py

Co-authored-by: Chi Wang <[email protected]>

* commented out cell in a notebook until issue #1205 is not fixed

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
davorrunje and sonichi authored Jan 11, 2024
1 parent fba7cae commit 2e519b0
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 116 deletions.
2 changes: 1 addition & 1 deletion autogen/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin

__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

Expand Down
34 changes: 26 additions & 8 deletions autogen/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)


def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]:
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
Expand Down Expand Up @@ -111,7 +111,7 @@ class ToolFunction(BaseModel):


def get_parameter_json_schema(
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Expand All @@ -124,10 +124,14 @@ def get_parameter_json_schema(
A Pydanitc model for the parameter
"""

def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
return v.__metadata__[0]
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k

Expand Down Expand Up @@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:


def get_parameters(
required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
required: List[str],
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API
Expand Down Expand Up @@ -278,7 +284,7 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
return model_dump(function)


def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model
Args:
Expand Down Expand Up @@ -319,15 +325,27 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:

# a function that loads the parameters before calling the original function
@functools.wraps(func)
def load_parameters_if_needed(*args, **kwargs):
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])

# call the original function
return func(*args, **kwargs)

return load_parameters_if_needed
@functools.wraps(func)
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])

# call the original function
return await func(*args, **kwargs)

if inspect.iscoroutinefunction(func):
return _a_load_parameters_if_needed
else:
return _load_parameters_if_needed


def serialize_to_str(x: Any) -> str:
Expand Down
133 changes: 27 additions & 106 deletions notebook/agentchat_function_call_async.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "9fb85afb",
"metadata": {},
"outputs": [
Expand All @@ -134,40 +134,46 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_thUjscBN349eGd6xh3XrVT18): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\u001b[32m**************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"Both the timer and the stopwatch for 5 seconds have been completed. \n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
Expand Down Expand Up @@ -239,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "2472f95c",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -274,105 +280,20 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "e2c9267a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"1) Create a timer for 5 seconds.\n",
"2) a stopwatch for 5 seconds.\n",
"3) Pretty print the result as md.\n",
"4) when 1-3 are done, terminate the group chat\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
"\n",
"The results are as follows:\n",
"\n",
"- Timer: Completed after `5 seconds`.\n",
"- Stopwatch: Recorded time of `5 seconds`.\n",
"\n",
"**Timer and Stopwatch Summary:**\n",
"Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
"\n",
"Now, let's proceed to terminate the group chat as requested.\n",
"\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
"Arguments: \n",
"{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
"\u001b[32m*********************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION terminate_group_chat...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
"[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
"\u001b[32m*****************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"outputs": [],
"source": [
"await user_proxy.a_initiate_chat( # noqa: F704\n",
" manager,\n",
" message=\"\"\"\n",
"1) Create a timer for 5 seconds.\n",
"2) a stopwatch for 5 seconds.\n",
"3) Pretty print the result as md.\n",
"4) when 1-3 are done, terminate the group chat\"\"\",\n",
")"
"# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
"# await user_proxy.a_initiate_chat( # noqa: F704\n",
"# manager,\n",
"# message=\"\"\"\n",
"# 1) Create a timer for 5 seconds.\n",
"# 2) a stopwatch for 5 seconds.\n",
"# 3) Pretty print the result as md.\n",
"# 4) when 1-3 are done, terminate the group chat\"\"\",\n",
"# )"
]
},
{
Expand Down
Loading

0 comments on commit 2e519b0

Please sign in to comment.