Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Fix schema building for complex types #6394

Merged
merged 12 commits into from
May 28, 2024
45 changes: 41 additions & 4 deletions python/semantic_kernel/schema/kernel_json_schema_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Any, get_type_hints
from typing import Any, Union, get_args, get_origin, get_type_hints

from semantic_kernel.kernel_pydantic import KernelBaseModel

Expand All @@ -11,12 +9,16 @@
float: "number",
list: "array",
dict: "object",
set: "array",
tuple: "array",
"int": "integer",
"str": "string",
"bool": "boolean",
"float": "number",
"list": "array",
"dict": "object",
"set": "array",
"tuple": "array",
"object": "object",
"array": "array",
}
Expand All @@ -30,10 +32,12 @@ def build(cls, parameter_type: type | str, description: str | None = None) -> di

if isinstance(parameter_type, str):
return cls.build_from_type_name(parameter_type, description)
if issubclass(parameter_type, KernelBaseModel):
if isinstance(parameter_type, KernelBaseModel):
return cls.build_model_schema(parameter_type, description)
if hasattr(parameter_type, "__annotations__"):
return cls.build_model_schema(parameter_type, description)
if hasattr(parameter_type, "__args__"):
return cls.handle_complex_type(parameter_type, description)
else:
schema = cls.get_json_schema(parameter_type)
if description:
Expand Down Expand Up @@ -74,3 +78,36 @@ def get_json_schema(cls, parameter_type: type) -> dict[str, Any]:
type_name = TYPE_MAPPING.get(parameter_type, "object")
schema = {"type": type_name}
return schema

@classmethod
def handle_complex_type(cls, parameter_type: type, description: str | None = None) -> dict[str, Any]:
"""Handles complex types like list[str], dict[str, int],
set[int], tuple[int, str], Union[int, str], and Optional[int]."""
moonbox3 marked this conversation as resolved.
Show resolved Hide resolved
origin = get_origin(parameter_type)
args = get_args(parameter_type)

if origin is list or origin is set:
item_type = args[0]
return {"type": "array", "items": cls.build(item_type), "description": description}
elif origin is dict:
_, value_type = args
moonbox3 marked this conversation as resolved.
Show resolved Hide resolved
additional_properties = cls.build(value_type)
if additional_properties == {"type": "object"}:
additional_properties["properties"] = {} # Account for differences in Python 3.10 dict
return {"type": "object", "additionalProperties": additional_properties, "description": description}
elif origin is tuple:
items = [cls.build(arg) for arg in args]
return {"type": "array", "items": items, "description": description}
elif origin is Union:
if len(args) == 2 and type(None) in args:
non_none_type = args[0] if args[1] is type(None) else args[1]
moonbox3 marked this conversation as resolved.
Show resolved Hide resolved
schema = cls.build(non_none_type)
schema["nullable"] = True
if description:
schema["description"] = description
return schema
else:
schemas = [cls.build(arg) for arg in args]
return {"anyOf": schemas, "description": description}
else:
return cls.get_json_schema(parameter_type)
165 changes: 165 additions & 0 deletions python/tests/unit/schema/test_schema_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.

import json
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import Mock

from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.schema.kernel_json_schema_builder import KernelJsonSchemaBuilder
Expand All @@ -15,6 +18,35 @@ class AnotherModel:
score: float


class MockClass:
    # Plain annotated class (no base class) used as the item type in
    # test_build_list_complex_type; the expected schema there lists exactly
    # these two annotated attributes as properties.
    name: str = None
    age: int = None


class MockModel:
    # Fixture for test_build_model_schema_for_many_types: one attribute per
    # kind of type (primitive, list, dict, set, tuple, union, optional).
    __annotations__ = {
        "id": int,
        "name": str,
        "is_active": bool,
        "scores": List[int],
        "metadata": Dict[str, Any],
        "tags": Set[str],
        "coordinates": Tuple[int, int],
        "status": Union[int, str],
        "optional_field": Optional[str],
    }
    # Per-field metadata stand-ins; each Mock only carries a `description`.
    # "coordinates" deliberately has no entry here, and the expected schema in
    # the test shows a null description for it.
    __fields__ = {
        "id": Mock(description="The ID of the model"),
        "name": Mock(description="The name of the model"),
        "is_active": Mock(description="Whether the model is active"),
        "tags": Mock(description="Tags associated with the model"),
        "status": Mock(description="The status of the model, either as an integer or a string"),
        "scores": Mock(description="The scores associated with the model"),
        "optional_field": Mock(description="An optional field that can be null"),
        "metadata": Mock(description="The optional metadata description"),
    }


def test_build_with_kernel_base_model():
expected_schema = {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
result = KernelJsonSchemaBuilder.build(ExampleModel)
Expand Down Expand Up @@ -71,3 +103,136 @@ def test_get_json_schema():
expected_schema = {"type": "integer"}
result = KernelJsonSchemaBuilder.get_json_schema(int)
assert result == expected_schema


def test_build_primitive_types():
    """Each primitive Python type builds to a schema holding only its JSON type name."""
    expected_names = {int: "integer", str: "string", bool: "boolean", float: "number"}
    for python_type, json_name in expected_names.items():
        assert KernelJsonSchemaBuilder.build(python_type) == {"type": json_name}


def test_build_list():
    """list[str] builds to an array schema whose items are strings."""
    expected = {"type": "array", "items": {"type": "string"}, "description": None}
    assert KernelJsonSchemaBuilder.build(list[str]) == expected


def test_build_list_complex_type():
    """A list of an annotated class builds to an array of object schemas."""
    item_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
    }
    expected = {"type": "array", "items": item_schema, "description": None}
    assert KernelJsonSchemaBuilder.build(list[MockClass]) == expected


def test_build_dict():
    """dict[str, int] builds to an object schema with integer additionalProperties."""
    expected = {"type": "object", "additionalProperties": {"type": "integer"}, "description": None}
    assert KernelJsonSchemaBuilder.build(dict[str, int]) == expected


def test_build_set():
    """set[int] builds to an array schema whose items are integers."""
    expected = {"type": "array", "items": {"type": "integer"}, "description": None}
    assert KernelJsonSchemaBuilder.build(set[int]) == expected


def test_build_tuple():
    """Tuple[int, str] builds to an array schema with one item schema per position."""
    positional_items = [{"type": "integer"}, {"type": "string"}]
    expected = {"type": "array", "items": positional_items, "description": None}
    assert KernelJsonSchemaBuilder.build(Tuple[int, str]) == expected


def test_build_union():
    """Union[int, str] builds to an anyOf schema listing both member schemas."""
    member_schemas = [{"type": "integer"}, {"type": "string"}]
    expected = {"anyOf": member_schemas, "description": None}
    assert KernelJsonSchemaBuilder.build(Union[int, str]) == expected


def test_build_optional():
    """Optional[int] builds to the integer schema flagged as nullable."""
    expected = {"type": "integer", "nullable": True}
    assert KernelJsonSchemaBuilder.build(Optional[int]) == expected


def test_build_model_schema_for_many_types():
    """Build the schema for MockModel and compare it with the expected JSON document."""
    schema = KernelJsonSchemaBuilder.build(MockModel)
    # The expectation is written as JSON text and parsed below, so `null` and
    # `true` become None and True before the comparison.
    expected = """
{
"type": "object",
"properties": {
"id": {
"type": "integer",
"description": "The ID of the model"
},
"name": {
"type": "string",
"description": "The name of the model"
},
"is_active": {
"type": "boolean",
"description": "Whether the model is active"
},
"scores": {
"type": "array",
"items": {"type": "integer"},
"description": "The scores associated with the model"
},
"metadata": {
"type": "object",
"additionalProperties": {
"type": "object",
"properties": {}
},
"description": "The optional metadata description"
},
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Tags associated with the model"
},
"coordinates": {
"type": "array",
"items": [
{"type": "integer"},
{"type": "integer"}
],
"description": null
},
"status": {
"anyOf": [
{"type": "integer"},
{"type": "string"}
],
"description": "The status of the model, either as an integer or a string"
},
"optional_field": {
"type": "string",
"nullable": true,
"description": "An optional field that can be null"
}
}
}
"""
    expected_schema = json.loads(expected)
    assert schema == expected_schema


def test_build_from_many_type_names():
    """Every supported type-name string resolves to the matching JSON schema type."""
    name_to_json_type = [
        ("int", "integer"),
        ("str", "string"),
        ("bool", "boolean"),
        ("float", "number"),
        ("list", "array"),
        ("dict", "object"),
        ("object", "object"),
        ("array", "array"),
    ]
    for type_name, json_type in name_to_json_type:
        assert KernelJsonSchemaBuilder.build_from_type_name(type_name) == {"type": json_type}


def test_get_json_schema_multiple():
    """get_json_schema maps each primitive type straight to its JSON type name."""
    for python_type, json_name in ((int, "integer"), (str, "string"), (bool, "boolean"), (float, "number")):
        assert KernelJsonSchemaBuilder.get_json_schema(python_type) == {"type": json_name}
Loading