Add decorator for function calling #1018

Merged
44 commits merged on Dec 25, 2023
Changes from 29 commits
Commits (44)
2e3a204
add function decorator to conversable agent
davorrunje Dec 15, 2023
a79356e
polishing
davorrunje Dec 18, 2023
b1995a5
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 18, 2023
38f7abe
polishing
davorrunje Dec 18, 2023
cceef1b
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 19, 2023
2d80140
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 19, 2023
721a9fa
added function decorator to the notebook with async function calls
davorrunje Dec 19, 2023
7486d06
added support for return type hint and JSON encoding of returned valu…
davorrunje Dec 20, 2023
79dc2e5
polishing
davorrunje Dec 20, 2023
6e05fba
polishing
davorrunje Dec 20, 2023
d9d624f
refactored async case
davorrunje Dec 20, 2023
b2882da
Python 3.8 support added
davorrunje Dec 21, 2023
513d3c5
Merge branch 'microsoft:main' into add-function-decorator
davorrunje Dec 21, 2023
b92a1fc
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 21, 2023
a0176c8
Merge branch 'add-function-decorator' of github.com:davorrunje/autoge…
davorrunje Dec 21, 2023
06fd4fb
polishing
davorrunje Dec 21, 2023
e4b131e
polishing
davorrunje Dec 21, 2023
4cd0b84
missing docs added
davorrunje Dec 21, 2023
522a247
refactoring and changes as requested
davorrunje Dec 21, 2023
2fcc353
getLogger
davorrunje Dec 21, 2023
3aa6686
documentation added
davorrunje Dec 21, 2023
e77f837
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 21, 2023
8f339fc
test fix
davorrunje Dec 22, 2023
b9214bb
test fix
davorrunje Dec 22, 2023
d6f4a21
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 22, 2023
4d6b342
added testing of agentchat_function_call_currency_calculator.ipynb to…
davorrunje Dec 22, 2023
8fdfedf
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 22, 2023
89df135
added support for Pydantic parameters in function decorator
davorrunje Dec 22, 2023
bbe1f4f
polishing
davorrunje Dec 23, 2023
46fee6f
Update website/docs/Use-Cases/agent_chat.md
ekzhu Dec 23, 2023
3ca57b1
Update website/docs/Use-Cases/agent_chat.md
ekzhu Dec 23, 2023
95772ca
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 24, 2023
53f2e9a
Merge branch 'add-function-decorator' of github.com:davorrunje/autoge…
davorrunje Dec 24, 2023
22bcb17
Merge branch 'main' into add-function-decorator
davorrunje Dec 24, 2023
b8b3a62
fixes problem with logprob parameter in openai.types.chat.chat_comple…
davorrunje Dec 24, 2023
4ac0bab
Merge remote-tracking branch 'origin/main' into add-function-decorator
davorrunje Dec 24, 2023
5db274d
get 100% code coverage on code added
davorrunje Dec 24, 2023
2ffb0bd
updated docs
davorrunje Dec 24, 2023
e226f31
default values added to JSON schema
davorrunje Dec 24, 2023
b0352b2
serialization using json.dump() added for values that are not string or BaseModel
davorrunje Dec 24, 2023
144f40d
added limit to openai version because of breaking changes in 1.5.0
davorrunje Dec 24, 2023
c354d83
Merge branch 'add-function-decorator' of github.com:davorrunje/autoge…
davorrunje Dec 24, 2023
e11bbf3
added line-by-line comments in docs to explain the process
davorrunje Dec 24, 2023
158698b
polishing
davorrunje Dec 25, 2023
2 changes: 1 addition & 1 deletion .gitignore
@@ -8,7 +8,7 @@ node_modules/
*.log

# Python virtualenv
.venv
.venv*

# Byte-compiled / optimized / DLL files
__pycache__/
110 changes: 110 additions & 0 deletions autogen/_pydantic.py
@@ -0,0 +1,110 @@
from typing import Any, Dict, Optional, Tuple, Type, Union, get_args

from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin

__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

if not PYDANTIC_V1:
from pydantic import TypeAdapter
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue

def type2schema(t: Optional[Type]) -> JsonSchemaValue:
"""Convert a type to a JSON schema

Args:
t (Type): The type to convert

Returns:
JsonSchemaValue: The JSON schema
"""
return TypeAdapter(t).json_schema()

def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict

Args:
model (BaseModel): The model to convert

Returns:
Dict[str, Any]: The dict representation of the model

"""
return model.model_dump()

def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string

Args:
model (BaseModel): The model to convert

Returns:
str: The JSON string representation of the model
"""
return model.model_dump_json()


# Remove this once we drop support for pydantic 1.x
else:
from pydantic import schema_of
from pydantic.typing import evaluate_forwardref as evaluate_forwardref

JsonSchemaValue = Dict[str, Any]

def type2schema(t: Optional[Type]) -> JsonSchemaValue:
"""Convert a type to a JSON schema

Args:
t (Type): The type to convert

Returns:
JsonSchemaValue: The JSON schema
"""
if PYDANTIC_V1:
if t is None:
return {"type": "null"}
elif get_origin(t) is Union:
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
elif get_origin(t) in [Tuple, tuple]:
prefixItems = [type2schema(tt) for tt in get_args(t)]
return {
"maxItems": len(prefixItems),
"minItems": len(prefixItems),
"prefixItems": prefixItems,
"type": "array",
}

d = schema_of(t)
if "title" in d:
d.pop("title")
if "description" in d:
d.pop("description")

return d

def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict

Args:
model (BaseModel): The model to convert

Returns:
Dict[str, Any]: The dict representation of the model

"""
return model.dict()

def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string

Args:
model (BaseModel): The model to convert

Returns:
str: The JSON string representation of the model
"""
return model.json()
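
For readers skimming the diff, here is a minimal usage sketch of the compatibility shim above; the `Currency` model and the printed values are illustrative assumptions, not part of this PR:

```python
# Illustrative sketch: exercising the pydantic v1/v2 compatibility shim above.
# The Currency model and the example values are hypothetical, not from the PR.
from typing import List, Optional

from pydantic import BaseModel

from autogen._pydantic import model_dump, model_dump_json, type2schema


class Currency(BaseModel):
    amount: float
    code: str


# type2schema yields a JSON schema regardless of the installed pydantic major.
print(type2schema(Optional[int]))  # e.g. {"anyOf": [{"type": "integer"}, {"type": "null"}]}
print(type2schema(List[str]))      # e.g. {"type": "array", "items": {"type": "string"}}

# model_dump / model_dump_json hide the v1 (.dict()/.json()) vs v2 API split.
c = Currency(amount=100.5, code="USD")
print(model_dump(c))       # {"amount": 100.5, "code": "USD"}
print(model_dump_json(c))  # '{"amount":100.5,"code":"USD"}'
```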
4 changes: 3 additions & 1 deletion autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from time import sleep

from autogen._pydantic import PYDANTIC_V1
from autogen.agentchat import Agent, UserProxyAgent
from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from autogen.math_utils import get_answer
@@ -384,7 +385,8 @@ class WolframAlphaAPIWrapper(BaseModel):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
if PYDANTIC_V1:
extra = Extra.forbid

@root_validator(skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
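
The `if PYDANTIC_V1:` guard above keeps the v1-only `extra = Extra.forbid` attribute from being evaluated under pydantic v2. As a hedged aside (not part of the diff), the same "reject unknown fields" intent can be written for both majors roughly like this:

```python
# Illustrative sketch only: forbidding extra fields under pydantic v1 and v2.
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

if PYDANTIC_V1:
    from pydantic import Extra

    class StrictModel(BaseModel):
        class Config:
            extra = Extra.forbid

else:
    from pydantic import ConfigDict

    class StrictModel(BaseModel):
        model_config = ConfigDict(extra="forbid")


# StrictModel(unknown_field=1) raises a validation error in both cases.
```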
167 changes: 163 additions & 4 deletions autogen/agentchat/conversable_agent.py
@@ -1,14 +1,16 @@
import asyncio
import copy
import functools
import inspect
import json
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

from autogen import OpenAIWrapper
from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from .. import OpenAIWrapper
from .._pydantic import model_dump_json
from ..code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from ..function_utils import get_function_schema, load_basemodels_if_needed
from .agent import Agent

try:
@@ -19,8 +21,12 @@ def colored(x, *args, **kwargs):
return x


__all__ = ("ConversableAgent",)

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


class ConversableAgent(Agent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
@@ -1330,3 +1336,156 @@ def can_execute_function(self, name: str) -> bool:
def function_map(self) -> Dict[str, Callable]:
"""Return the function map."""
return self._function_map

def _wrap_function(self, func: F) -> F:
"""Wrap the function to dump the return value to json.

Handles both sync and async functions.

Args:
func: the function to be wrapped.

Returns:
The wrapped function.
"""

@load_basemodels_if_needed
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)
return retval if isinstance(retval, str) else model_dump_json(retval)

@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
return retval if isinstance(retval, str) else model_dump_json(retval)

wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func

# needed for testing
wrapped_func._origin = func

return wrapped_func

def register_for_llm(
self,
*,
name: Optional[str] = None,
description: Optional[str] = None,
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be used by an agent.

Its return value is used to decorate a function to be registered to the agent. The function uses type hints to
specify the arguments and return type. The function name is used as the default name for the function,
but a custom name can be provided. The function description is used to describe the function in the
agent's configuration.

Args:
name (optional(str)): name of the function. If None, the function name will be used (default: None).
description (optional(str)): description of the function (default: None). It is mandatory
for the initial decorator, but the following ones can omit it.

Returns:
The decorator for registering a function to be used by an agent.

Examples:
```
@user_proxy.register_for_execution()
@agent2.register_for_llm()
@agent1.register_for_llm(description="This is a very useful function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int = 2, c: float = 3.14) -> str:
return a + str(b * c)
```

"""

def _decorator(func: F) -> F:
"""Decorator for registering a function to be used by an agent.

Args:
func: the function to be registered.

Returns:
The function to be registered, with the _description attribute set to the function description.

Raises:
ValueError: if the function description is not provided and not propagated by a previous decorator.
RuntimeError: if the LLM config is not set up before registering a function.

"""
# name can be overwritten by the parameter; by default it is the same as the function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

# description is propagated from the previous decorator, but it is mandatory for the first one
if description:
func._description = description
else:
if not hasattr(func, "_description"):
raise ValueError("Function description is required, none found.")

# get JSON schema for the function
f = get_function_schema(func, name=func._name, description=func._description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

self.update_function_signature(f, is_remove=False)

return func

return _decorator

def register_for_execution(
self,
name: Optional[str] = None,
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be executed by an agent.

Its return value is used to decorate a function to be registered to the agent.

Args:
name (optional(str)): name of the function. If None, the function name will be used (default: None).

Returns:
The decorator for registering a function to be used by an agent.

Examples:
```
@user_proxy.register_for_execution()
@agent2.register_for_llm()
@agent1.register_for_llm(description="This is a very useful function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int = 2, c: float = 3.14):
return a + str(b * c)
```

"""

def _decorator(func: F) -> F:
"""Decorator for registering a function to be used by an agent.

Args:
func: the function to be registered.

Returns:
The function to be registered, with the _description attribute set to the function description.

Raises:
ValueError: if the function description is not provided and not propagated by a previous decorator.

"""
# name can be overwritten by the parameter; by default it is the same as the function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

self.register_function({func._name: self._wrap_function(func)})

return func

return _decorator
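
Putting the two new decorators together, here is a hedged end-to-end sketch in the spirit of the currency-calculator notebook added by this PR; the agent names, the llm_config contents, and the exchange rate are placeholders:

```python
# Illustrative sketch: registering one function both for the LLM and for execution.
# The llm_config below is a placeholder; supply a real config_list to run it.
from typing import Annotated

from pydantic import BaseModel, Field

import autogen


class Currency(BaseModel):
    currency: Annotated[str, Field(description="Currency symbol, e.g. 'USD' or 'EUR'")]
    amount: Annotated[float, Field(0, description="Amount of currency", ge=0)]


assistant = autogen.AssistantAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": "<your-key>"}]},
)
user_proxy = autogen.UserProxyAgent(name="user_proxy", human_input_mode="NEVER")


# The outermost decorator registers the executor (function_map); the innermost one,
# which must carry the description, registers the JSON schema with the LLM-facing agent.
@user_proxy.register_for_execution()
@assistant.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
    base: Annotated[Currency, "Base amount and currency"],
    quote_currency: Annotated[str, "Quote currency symbol"] = "USD",
) -> Currency:
    rate = 1.1 if (base.currency, quote_currency) == ("EUR", "USD") else 1.0
    # Pydantic return values are serialized via model_dump_json by _wrap_function.
    return Currency(amount=rate * base.amount, currency=quote_currency)
```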