Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix types in tool tests #2285

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions autogen/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue

def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue:
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema

Args:
Expand Down Expand Up @@ -55,7 +55,7 @@ def model_dump_json(model: BaseModel) -> str:

JsonSchemaValue = Dict[str, Any] # type: ignore[misc]

def type2schema(t: Optional[Type[Any]]) -> JsonSchemaValue:
def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema

Args:
Expand Down
4 changes: 1 addition & 3 deletions autogen/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ class ToolFunction(BaseModel):
function: Annotated[Function, Field(description="Function under tool")]


def get_parameter_json_schema(
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API

Args:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ files = [
"autogen/_pydantic.py",
"autogen/function_utils.py",
"autogen/io",
"test/test_pydantic.py",
"test/test_function_utils.py",
"test/io",
]

Expand Down
32 changes: 16 additions & 16 deletions test/test_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import pytest
from pydantic import BaseModel, Field
Expand All @@ -25,11 +25,11 @@
)


def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d):
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1, *, d): # type: ignore[no-untyped-def]
pass


def g(
def g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"],
b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1,
Expand All @@ -39,7 +39,7 @@ def g(
pass


async def a_g(
async def a_g( # type: ignore[empty-body]
a: Annotated[str, "Parameter a"],
b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1,
Expand Down Expand Up @@ -83,7 +83,7 @@ class B(BaseModel):
b: float
c: str

expected = {
expected: Dict[str, Any] = {
"description": "b",
"properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}},
"required": ["b", "c"],
Expand All @@ -107,7 +107,7 @@ def test_get_default_values() -> None:


def test_get_param_annotations() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass

expected = {"a": Annotated[str, "Parameter a"], "c": Annotated[float, "Parameter c"]}
Expand All @@ -119,14 +119,14 @@ def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"]


def test_get_missing_annotations() -> None:
def _f1(a: str, b=2):
def _f1(a: str, b=2): # type: ignore[no-untyped-def]
pass

missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f1), ["a"])
assert missing == set()
assert unannotated_with_default == {"b"}

def _f2(a: str, b) -> str:
def _f2(a: str, b) -> str: # type: ignore[empty-body,no-untyped-def]
"ok"

missing, unannotated_with_default = get_missing_annotations(get_typed_signature(_f2), ["a", "b"])
Expand All @@ -142,7 +142,7 @@ def _f3() -> None:


def test_get_parameters() -> None:
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0):
def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"] = 1.0): # type: ignore[no-untyped-def]
pass

typed_signature = get_typed_signature(f)
Expand All @@ -165,7 +165,7 @@ def f(a: Annotated[str, "Parameter a"], b=1, c: Annotated[float, "Parameter c"]


def test_get_function_schema_no_return_type() -> None:
def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1):
def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1): # type: ignore[no-untyped-def]
pass

expected = (
Expand All @@ -182,7 +182,7 @@ def f(a: Annotated[str, "Parameter a"], b: int, c: float = 0.1):
def test_get_function_schema_unannotated_with_default() -> None:
with unittest.mock.patch("autogen.function_utils.logger.warning") as mock_logger_warning:

def f(
def f( # type: ignore[no-untyped-def]
a: Annotated[str, "Parameter a"], b=2, c: Annotated[float, "Parameter c"] = 0.1, d="whatever", e=None
) -> str:
return "ok"
Expand All @@ -195,7 +195,7 @@ def f(


def test_get_function_schema_missing() -> None:
def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float:
def f(a: Annotated[str, "Parameter a"], b, c: Annotated[float, "Parameter c"] = 0.1) -> float: # type: ignore[no-untyped-def, empty-body]
pass

expected = (
Expand Down Expand Up @@ -291,7 +291,7 @@ class Currency(BaseModel):


def test_get_function_schema_pydantic() -> None:
def currency_calculator(
def currency_calculator( # type: ignore[empty-body]
base: Annotated[Currency, "Base currency: amount and currency symbol"],
quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR",
) -> Currency:
Expand Down Expand Up @@ -346,12 +346,12 @@ def currency_calculator(

def test_get_load_param_if_needed_function() -> None:
assert get_load_param_if_needed_function(CurrencySymbol) is None
assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency(
assert get_load_param_if_needed_function(Currency)({"currency": "USD", "amount": 123.45}, Currency) == Currency( # type: ignore[misc]
currency="USD", amount=123.45
)

f = get_load_param_if_needed_function(Annotated[Currency, "amount and a symbol of a currency"])
actual = f({"currency": "USD", "amount": 123.45}, Currency)
actual = f({"currency": "USD", "amount": 123.45}, Currency) # type: ignore[misc]
expected = Currency(currency="USD", amount=123.45)
assert actual == expected, actual

Expand Down Expand Up @@ -391,7 +391,7 @@ async def f(
assert actual[1] == "EUR"


def test_serialize_to_json():
def test_serialize_to_json() -> None:
assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123"
assert serialize_to_str([123, 456]) == "[123, 456]"
Expand Down
Loading