Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix registration of async functions #1201

Merged
merged 7 commits into from
Jan 11, 2024
Merged
2 changes: 1 addition & 1 deletion autogen/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin

__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

Expand Down
4 changes: 1 addition & 3 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,9 +912,7 @@ async def a_generate_tool_calls_reply(
return True, {
"role": "tool",
"tool_responses": tool_returns,
"content": "\n\n".join(
[self._str_for_tool_response(tool_return["content"]) for tool_return in tool_returns]
),
"content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]),
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
}

return False, None
Expand Down
34 changes: 26 additions & 8 deletions autogen/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)


def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]:
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function

Args:
Expand Down Expand Up @@ -111,7 +111,7 @@ class ToolFunction(BaseModel):


def get_parameter_json_schema(
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API

Expand All @@ -124,10 +124,14 @@ def get_parameter_json_schema(
A Pydantic model for the parameter
"""

def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
return v.__metadata__[0]
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
else:
return k

Expand Down Expand Up @@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:


def get_parameters(
required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
required: List[str],
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API

Expand Down Expand Up @@ -278,7 +284,7 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
return model_dump(function)


def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model

Args:
Expand Down Expand Up @@ -319,15 +325,27 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:

# a function that loads the parameters before calling the original function
@functools.wraps(func)
def load_parameters_if_needed(*args, **kwargs):
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])

# call the original function
return func(*args, **kwargs)

return load_parameters_if_needed
@functools.wraps(func)
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])

# call the original function
return await func(*args, **kwargs)

if inspect.iscoroutinefunction(func):
return _a_load_parameters_if_needed
else:
return _load_parameters_if_needed


def serialize_to_str(x: Any) -> str:
Expand Down
165 changes: 165 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import asyncio
import copy
import sys
import time
from typing import Any, Callable, Dict, Literal
import unittest

import pytest
from unittest.mock import patch
from pydantic import BaseModel, Field
from typing_extensions import Annotated
import autogen

from autogen.agentchat import ConversableAgent, UserProxyAgent
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai

try:
Expand Down Expand Up @@ -445,6 +451,8 @@ def currency_calculator(
== '{"currency":"EUR","amount":100.1}'
)

assert not asyncio.coroutines.iscoroutinefunction(currency_calculator)
davorrunje marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.asyncio
async def test__wrap_function_async():
Expand Down Expand Up @@ -481,6 +489,8 @@ async def currency_calculator(
== '{"currency":"EUR","amount":100.1}'
)

assert asyncio.coroutines.iscoroutinefunction(currency_calculator)
davorrunje marked this conversation as resolved.
Show resolved Hide resolved


def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
return {k: v._origin for k, v in d.items()}
Expand Down Expand Up @@ -624,6 +634,161 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]):
assert get_origin(user_proxy_1.function_map) == expected_function_map


@pytest.mark.skipif(
    skip or not sys.version.startswith("3.10"),
    reason="do not run if openai is not installed or py!=3.10",
)
def test_function_restration_e2e_sync() -> None:
    """End-to-end test that synchronous functions registered via
    `register_for_llm` / `register_for_execution` are invoked with the
    expected arguments during a chat.

    NOTE(review): name has a typo ("restration" -> "registration");
    rename when convenient — nothing calls a test function by name.
    """
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        filter_dict={
            "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
        },
        file_location=KEY_LOC,
    )

    llm_config = {
        "config_list": config_list,
    }

    coder = autogen.AssistantAgent(
        name="chatbot",
        system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
        llm_config=llm_config,
    )

    # create a UserProxyAgent instance named "user_proxy"
    user_proxy = autogen.UserProxyAgent(
        name="user_proxy",
        system_message="A proxy for the user for executing code.",
        is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
        human_input_mode="NEVER",
        max_consecutive_auto_reply=10,
        code_execution_config={"work_dir": "coding"},
    )

    # mocks record the arguments each registered function is called with
    timer_mock = unittest.mock.MagicMock()
    stopwatch_mock = unittest.mock.MagicMock()

    # An example sync function (the original comment wrongly said "async")
    @user_proxy.register_for_execution()
    @coder.register_for_llm(description="create a timer for N seconds")
    def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
        print("timer is running")
        for i in range(int(num_seconds)):
            print(".", end="")
            time.sleep(0.01)
        print()

        timer_mock(num_seconds=num_seconds)
        return "Timer is done!"

    # Another example sync function
    @user_proxy.register_for_execution()
    @coder.register_for_llm(description="create a stopwatch for N seconds")
    def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
        print("stopwatch is running")
        for i in range(int(num_seconds)):
            print(".", end="")
            time.sleep(0.01)
        print()

        stopwatch_mock(num_seconds=num_seconds)
        return "Stopwatch is done!"

    # start the conversation; the model is expected to call both registered functions
    user_proxy.initiate_chat(
        coder,
        message="Create a timer for 2 seconds and then a stopwatch for 3 seconds.",
    )

    timer_mock.assert_called_once_with(num_seconds="2")
    stopwatch_mock.assert_called_once_with(num_seconds="3")


@pytest.mark.skipif(
    skip or not sys.version.startswith("3.10"),
    reason="do not run if openai is not installed or py!=3.10",
)
@pytest.mark.asyncio()
async def test_function_restration_e2e_async() -> None:
    """End-to-end test that async functions registered via
    `register_for_llm` / `register_for_execution` are awaited and invoked
    with the expected arguments during a chat (sync functions must still
    work alongside them).

    NOTE(review): name has a typo ("restration" -> "registration");
    rename when convenient — nothing calls a test function by name.
    """
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        filter_dict={
            "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
        },
        file_location=KEY_LOC,
    )

    llm_config = {
        "config_list": config_list,
    }

    coder = autogen.AssistantAgent(
        name="chatbot",
        system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
        llm_config=llm_config,
    )

    # create a UserProxyAgent instance named "user_proxy"
    user_proxy = autogen.UserProxyAgent(
        name="user_proxy",
        system_message="A proxy for the user for executing code.",
        is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
        human_input_mode="NEVER",
        max_consecutive_auto_reply=10,
        code_execution_config={"work_dir": "coding"},
    )

    # mocks record the arguments each registered function is called with
    timer_mock = unittest.mock.MagicMock()
    stopwatch_mock = unittest.mock.MagicMock()

    # An example async function
    @user_proxy.register_for_execution()
    @coder.register_for_llm(description="create a timer for N seconds")
    async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
        print("timer is running")
        for i in range(int(num_seconds)):
            print(".", end="")
            await asyncio.sleep(0.01)
        print()

        timer_mock(num_seconds=num_seconds)
        return "Timer is done!"

    # An example sync function registered alongside the async one
    @user_proxy.register_for_execution()
    @coder.register_for_llm(description="create a stopwatch for N seconds")
    def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
        print("stopwatch is running")
        for i in range(int(num_seconds)):
            print(".", end="")
            time.sleep(0.01)
        print()

        stopwatch_mock(num_seconds=num_seconds)
        return "Stopwatch is done!"

    # start the conversation; 'await' executes the coroutine — without it,
    # a_initiate_chat would merely return a coroutine object
    await user_proxy.a_initiate_chat(  # noqa: F704
        coder,
        message="Create a timer for 4 seconds and then a stopwatch for 5 seconds.",
    )

    timer_mock.assert_called_once_with(num_seconds="4")
    stopwatch_mock.assert_called_once_with(num_seconds="5")


@pytest.mark.skipif(
skip,
reason="do not run if skipping openai",
Expand Down
23 changes: 22 additions & 1 deletion test/test_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple
Expand Down Expand Up @@ -355,21 +356,41 @@ def test_get_load_param_if_needed_function() -> None:
assert actual == expected, actual


def test_load_basemodels_if_needed() -> None:
def test_load_basemodels_if_needed_sync() -> None:
    """The decorator must leave a plain function synchronous and convert dict
    arguments into their annotated BaseModel types before the call."""

    @load_basemodels_if_needed
    def pair(
        base: Annotated[Currency, "Base currency"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
    ) -> Tuple[Currency, CurrencySymbol]:
        return base, quote_currency

    # wrapping a sync function must not produce a coroutine function
    assert not asyncio.coroutines.iscoroutinefunction(pair)

    base_currency, quote = pair(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
    # the dict passed as `base` was loaded into a Currency model instance
    assert isinstance(base_currency, Currency)
    assert (base_currency.currency, base_currency.amount) == ("USD", 123.45)
    assert quote == "EUR"


@pytest.mark.asyncio
async def test_load_basemodels_if_needed_async() -> None:
    """The decorator must keep a coroutine function awaitable and convert dict
    arguments into their annotated BaseModel types before awaiting."""

    @load_basemodels_if_needed
    async def f(
        base: Annotated[Currency, "Base currency"],
        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
    ) -> Tuple[Currency, CurrencySymbol]:
        return base, quote_currency

    # the wrapper must preserve the async nature of the function
    assert asyncio.coroutines.iscoroutinefunction(f)

    actual = await f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
    # the dict passed as `base` was loaded into a Currency model instance
    assert isinstance(actual[0], Currency)
    assert actual[0].amount == 123.45
    assert actual[0].currency == "USD"
    assert actual[1] == "EUR"


def test_serialize_to_json():
assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123"
Expand Down
Loading