From 0113d80278b952551483035e59a18474de5f1ddc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Sun, 13 Jul 2025 21:46:47 -0400 Subject: [PATCH 1/3] Expand additional $refs for structured_output Addresses issue#337 Previously lists of items that were optional were not correctly expanding $refs. Derived classes also weren't having their $refs expanded as the subclass already had a "properties" object which bypassed $ref expansion --- pyproject.toml | 1 + src/strands/tools/structured_output.py | 8 +++- tests/strands/tools/test_structured_output.py | 43 +++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 032376be1..53e523c77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", "opentelemetry-instrumentation-threading>=0.51b0,<1.00b0", + "notebook>=7.4.4", ] [project.urls] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index 5421cdc69..6f2739d88 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -54,7 +54,9 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: # Process each nested property for nested_prop_name, nested_prop_value in prop_value["properties"].items(): - processed_prop["properties"][nested_prop_name] = nested_prop_value + is_required = "required" in prop_value and nested_prop_name in prop_value["required"] + sub_property = _process_property(nested_prop_value, schema.get("$defs", {}), is_required) + processed_prop["properties"][nested_prop_name] = sub_property # Copy required fields if present if "required" in prop_value: @@ -137,6 +139,10 @@ def _process_property( if "description" in prop: result["description"] = prop["description"] + # Need to process item refs as well (#337) + if "items" in result: + result["items"] = _process_property(result["items"], defs) + return result # Handle direct references diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 2e354b831..110d8ed0f 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,3 +1,4 @@ +import json from typing import Literal, Optional import pytest @@ -226,3 +227,45 @@ class EmptyDocUser(BaseModel): tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) assert tool_spec["description"] == "EmptyDocUser structured output tool" + + +def test_convert_pydantic_with_items_refs(): + """Test that no $refs exist after lists of different components.""" + + class Address(BaseModel): + postal_code: Optional[str] = None + + class Person(BaseModel): + """Complete person information.""" + + list_of_items: list[Address] + list_of_items_nullable: Optional[list[Address]] + list_of_item_or_nullable: list[Optional[Address]] + + tool_spec = convert_pydantic_to_tool_spec(Person) + raw_json = json.dumps(tool_spec, indent=2) + + assert "$ref" not in raw_json + + +def test_convert_pydantic_with_refs(): + """Test that no $refs exist after processing complex hierarchies.""" + + class Address(BaseModel): + street: str + city: str + country: str + postal_code: Optional[str] = None + + class Contact(BaseModel): + address: Address + + class Person(BaseModel): + """Complete person information.""" + + contact: Contact = Field(description="Contact methods") + + tool_spec = convert_pydantic_to_tool_spec(Person) + raw_json = json.dumps(tool_spec, indent=2) + + assert "$ref" not in raw_json From d473fdbc3f1f1deb5053e14577ac8286dc2031fc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 14 Jul 2025 10:04:14 -0400 Subject: [PATCH 2/3] Remove pyproject change --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 53e523c77..032376be1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", "opentelemetry-instrumentation-threading>=0.51b0,<1.00b0", - "notebook>=7.4.4", ] [project.urls] From b586896e694b2384c3eac0673be30112937ff690 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 14 Jul 2025 13:43:06 -0400 Subject: [PATCH 3/3] Assert on generated input spec instead of json refs --- tests/strands/tools/test_structured_output.py | 84 +++++++++++++++++-- 1 file changed, 79 insertions(+), 5 deletions(-) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 110d8ed0f..97b68a34c 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,3 @@ -import json from typing import Literal, Optional import pytest @@ -243,9 +242,53 @@ class Person(BaseModel): list_of_item_or_nullable: list[Optional[Address]] tool_spec = convert_pydantic_to_tool_spec(Person) - raw_json = json.dumps(tool_spec, indent=2) - assert "$ref" not in raw_json + expected_spec = { + "description": "Complete person information.", + "inputSchema": { + "json": { + "description": "Complete person information.", + "properties": { + "list_of_item_or_nullable": { + "items": { + "anyOf": [ + { + "properties": {"postal_code": {"type": ["string", "null"]}}, + "title": "Address", + "type": "object", + }, + {"type": "null"}, + ] + }, + "title": "List Of Item Or Nullable", + "type": "array", + }, + "list_of_items": { + "items": { + "properties": {"postal_code": {"type": ["string", "null"]}}, + "title": "Address", + "type": "object", + }, + "title": "List Of Items", + "type": "array", + }, + "list_of_items_nullable": { + "items": { + "properties": {"postal_code": {"type": ["string", "null"]}}, + "title": "Address", + "type": "object", + }, + "type": ["array", "null"], + }, + }, + "required": ["list_of_items", "list_of_item_or_nullable"], + "title": "Person", + "type": "object", + } + }, + "name": "Person", + } + assert tool_spec == expected_spec def test_convert_pydantic_with_refs(): @@ -266,6 +309,37 @@ class Person(BaseModel): contact: Contact = Field(description="Contact methods") tool_spec = convert_pydantic_to_tool_spec(Person) - raw_json = json.dumps(tool_spec, indent=2) - assert "$ref" not in raw_json + expected_spec = { + "description": "Complete person information.", + "inputSchema": { + "json": { + "description": "Complete person information.", + "properties": { + "contact": { + "description": "Contact methods", + "properties": { + "address": { + "properties": { + "city": {"title": "City", "type": "string"}, + "country": {"title": "Country", "type": "string"}, + "postal_code": {"type": ["string", "null"]}, + "street": {"title": "Street", "type": "string"}, + }, + "required": ["street", "city", "country"], + "title": "Address", + "type": "object", + } + }, + "required": ["address"], + "type": "object", + } + }, + "required": ["contact"], + "title": "Person", + "type": "object", + } + }, + "name": "Person", + } + assert tool_spec == expected_spec