diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index fd857707c..878cbefcd 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -16,7 +16,7 @@
 import random
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
 from uuid import uuid4
 
 from opentelemetry import trace
@@ -423,7 +423,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
         messages.append({"role": "user", "content": [{"text": prompt}]})
 
         # get the structured output from the model
-        return self.model.structured_output(output_model, messages, self.callback_handler)
+        events = self.model.structured_output(output_model, messages)
+        for event in events:
+            if "callback" in event:
+                self.callback_handler(**cast(dict, event["callback"]))
+
+        return event["output"]
 
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 51089d47e..d70ed0c8c 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -7,14 +7,13 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
+from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast
 
 import anthropic
 from pydantic import BaseModel
 from typing_extensions import Required, Unpack, override
 
 from ..event_loop.streaming import process_stream
-from ..handlers.callback_handler import PrintingCallbackHandler
 from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
@@ -378,24 +377,24 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
     @override
     def structured_output(
-        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
-    ) -> T:
+        self, output_model: Type[T], prompt: Messages
+    ) -> Generator[dict[str, Union[T, Any]], None, None]:
         """Get structured output from the model.
 
         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
             prompt(Messages): The prompt messages to use for the agent.
-            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Yields:
+            Model events with the last being the structured output.
""" - callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) for event in process_stream(response, prompt): - if "callback" in event: - callback_handler(**event["callback"]) - else: - stop_reason, messages, _, _ = event["stop"] + yield event + + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") @@ -413,4 +412,4 @@ def structured_output( if output_response is None: raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") - return output_model(**output_response) + yield {"output": output_model(**output_response)} diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 5448f96e0..31d57fd61 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast +from typing import Any, Generator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -15,7 +15,6 @@ from typing_extensions import TypedDict, Unpack, override from ..event_loop.streaming import process_stream -from ..handlers.callback_handler import PrintingCallbackHandler from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException @@ -521,24 +520,24 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt messages to use for the agent. - callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Yields: + Model events with the last being the structured output. 
""" - callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) for event in process_stream(response, prompt): - if "callback" in event: - callback_handler(**event["callback"]) - else: - stop_reason, messages, _, _ = event["stop"] + yield event + + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") @@ -556,4 +555,4 @@ def structured_output( if output_response is None: raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") - return output_model(**output_response) + yield {"output": output_model(**output_response)} diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 661381863..8c5020637 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast +from typing import Any, Generator, Optional, Type, TypedDict, TypeVar, Union, cast import litellm from litellm.utils import supports_response_schema @@ -105,15 +105,16 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] @override def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt messages to use for the agent. - callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + Yields: + Model events with the last being the structured output. """ # The LiteLLM `Client` inits with Chat(). # Chat() inits with self.completions @@ -136,7 +137,8 @@ def structured_output( # Parse the tool call content as JSON tool_call_data = json.loads(choice.message.content) # Instantiate the output model with the parsed data - return output_model(**tool_call_data) + yield {"output": output_model(**tool_call_data)} + return except (json.JSONDecodeError, TypeError, ValueError) as e: raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 755e07ad9..751b1b1b7 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,7 @@ import json import logging import mimetypes -from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast +from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -390,14 +390,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: @override def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt messages to use for the agent. - callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
+
+        Yields:
+            Model events with the last being the structured output.
 
         Raises:
             NotImplementedError: Structured output is not currently supported for LlamaAPI models.
diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index b062fe14d..431b1f45b 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -5,7 +5,7 @@
 
 import json
 import logging
-from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
+from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast
 
 from ollama import Client as OllamaClient
 from pydantic import BaseModel
@@ -316,14 +316,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
     @override
     def structured_output(
-        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
-    ) -> T:
+        self, output_model: Type[T], prompt: Messages
+    ) -> Generator[dict[str, Union[T, Any]], None, None]:
         """Get structured output from the model.
 
         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
             prompt(Messages): The prompt messages to use for the agent.
-            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Yields:
+            Model events with the last being the structured output.
         """
         formatted_request = self.format_request(messages=prompt)
         formatted_request["format"] = output_model.model_json_schema()
@@ -332,6 +334,6 @@ def structured_output(
 
         try:
             content = response.message.content.strip()
-            return output_model.model_validate_json(content)
+            yield {"output": output_model.model_validate_json(content)}
         except Exception as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 783ce3794..7ec16efed 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -4,7 +4,7 @@
 """
 
 import logging
-from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast
+from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
 
 import openai
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -133,14 +133,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
     @override
     def structured_output(
-        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
-    ) -> T:
+        self, output_model: Type[T], prompt: Messages
+    ) -> Generator[dict[str, Union[T, Any]], None, None]:
         """Get structured output from the model.
 
         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
             prompt(Messages): The prompt messages to use for the agent.
-            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Yields:
+            Model events with the last being the structured output.
""" response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore model=self.get_config()["model_id"], @@ -159,6 +161,6 @@ def structured_output( break if parsed: - return parsed + yield {"output": parsed} else: raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 071c8a511..0a289cf53 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,7 @@ import abc import logging -from typing import Any, Callable, Iterable, Optional, Type, TypeVar +from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union from pydantic import BaseModel @@ -45,17 +45,16 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt messages to use for the agent. - callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. - Returns: - The structured output as a serialized instance of the output model. + Yields: + Model events with the last being the structured output. Raises: ValidationException: The response format from the model does not match the output_model diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 8ff37d359..e5a8ce6b1 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,7 +11,7 @@ import json import logging import mimetypes -from typing import Any, Callable, Optional, Type, TypeVar, cast +from typing import Any, Generator, Optional, Type, TypeVar, Union, cast from pydantic import BaseModel from typing_extensions import override @@ -295,13 +295,15 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: @override def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model(Type[BaseModel]): The output model to use for the agent. prompt(Messages): The prompt to use for the agent. - callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Yields: + Model events with the last being the structured output. 
""" - return output_model() + yield {"output": output_model()} diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c813a1a91..7fd1bea6f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -898,7 +898,7 @@ class User(BaseModel): def test_agent_method_structured_output(agent): # Mock the structured_output method on the model expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") - agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}]) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -906,9 +906,7 @@ def test_agent_method_structured_output(agent): assert result == expected_user # Verify the model's structured_output was called with correct arguments - agent.model.structured_output.assert_called_once_with( - User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler - ) + agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}]) @pytest.mark.asyncio diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index a0cfc4d4a..203352151 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,6 +1,7 @@ import unittest.mock import anthropic +import pydantic import pytest import strands @@ -41,6 +42,15 @@ def system_prompt(): return "s1" +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + def test__init__model_configs(anthropic_client, model_id, max_tokens): _ = anthropic_client @@ -688,3 +698,58 @@ def test_stream_bad_request_error(anthropic_client, model): with pytest.raises(anthropic.BadRequestError, match="bad"): next(model.stream({})) + + +def test_structured_output(anthropic_client, model, test_output_model_cls): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + events = [ + unittest.mock.Mock(type="message_start", model_dump=unittest.mock.Mock(return_value={"type": "message_start"})), + unittest.mock.Mock( + type="content_block_start", + model_dump=unittest.mock.Mock( + return_value={ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "tool_use", "id": "123", "name": "TestOutputModel"}, + } + ), + ), + unittest.mock.Mock( + type="content_block_delta", + model_dump=unittest.mock.Mock( + return_value={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "input_json_delta", "partial_json": '{"name": "John", "age": 30}'}, + }, + ), + ), + unittest.mock.Mock( + type="content_block_stop", + model_dump=unittest.mock.Mock(return_value={"type": "content_block_stop", "index": 0}), + ), + unittest.mock.Mock( + type="message_stop", + model_dump=unittest.mock.Mock( + return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}} + ), + ), + unittest.mock.Mock( + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) + ), + ), + ), + ] + + mock_stream = unittest.mock.MagicMock() + mock_stream.__iter__.return_value = iter(events) + anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream + + stream = model.structured_output(test_output_model_cls, messages) + + tru_result = list(stream)[-1] + exp_result = {"output": 
test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 3ed72973b..1d045f3b1 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3,6 +3,7 @@ import unittest.mock import boto3 +import pydantic import pytest from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError, EventStreamError @@ -84,6 +85,15 @@ def cache_type(): return "default" +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + def test__init__default_model_id(bedrock_client): """Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided.""" _ = bedrock_client @@ -1035,6 +1045,26 @@ def test_converse_output_guardrails_redacts_output(bedrock_client): bedrock_client.converse_stream.assert_not_called() +def test_structured_output(bedrock_client, model, test_output_model_cls): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + bedrock_client.converse_stream.return_value = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + } + + stream = model.structured_output(test_output_model_cls, messages) + + tru_output = list(stream)[-1] + exp_output = {"output": test_output_model_cls(name="John", age=30)} + assert tru_output == exp_output + + @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") def test_add_note_on_client_error(bedrock_client, model): """Test that add_note is called on ClientError with region and model ID information.""" diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 528d14982..50a073ad3 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,5 +1,6 @@ import unittest.mock +import pydantic import pytest import strands @@ -39,6 +40,15 @@ def system_prompt(): return "s1" +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + def test__init__(litellm_client_cls, model_id): model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) @@ -103,3 +113,22 @@ def test_update_config(model, model_id): def test_format_request_message_content(content, exp_result): tru_result = LiteLLMModel.format_request_message_content(content) assert tru_result == exp_result + + +def test_structured_output(litellm_client, model, test_output_model_cls): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_choice = unittest.mock.Mock() + mock_choice.finish_reason = "tool_calls" + mock_choice.message.content = '{"name": "John", "age": 30}' + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + litellm_client.chat.completions.create.return_value = mock_response + + with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True): + stream = model.structured_output(test_output_model_cls, messages) + tru_result = list(stream)[-1] + + exp_result = {"output": test_output_model_cls(name="John", age=30)} 
+    assert tru_result == exp_result
diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py
index fe590dffc..ead4caba0 100644
--- a/tests/strands/models/test_ollama.py
+++ b/tests/strands/models/test_ollama.py
@@ -1,6 +1,7 @@
 import json
 import unittest.mock
 
+import pydantic
 import pytest
 
 import strands
@@ -41,6 +42,15 @@ def system_prompt():
     return "s1"
 
 
+@pytest.fixture
+def test_output_model_cls():
+    class TestOutputModel(pydantic.BaseModel):
+        name: str
+        age: int
+
+    return TestOutputModel
+
+
 def test__init__model_configs(ollama_client, model_id, host):
     _ = ollama_client
 
@@ -457,3 +467,18 @@ def test_stream_with_tool_calls(ollama_client, model):
 
     assert tru_events == exp_events
     ollama_client.chat.assert_called_once_with(**request)
+
+
+def test_structured_output(ollama_client, model, test_output_model_cls):
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+
+    mock_response = unittest.mock.Mock()
+    mock_response.message.content = '{"name": "John", "age": 30}'
+
+    ollama_client.chat.return_value = mock_response
+
+    stream = model.structured_output(test_output_model_cls, messages)
+
+    tru_result = list(stream)[-1]
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 4c1f85287..ae0332864 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -1,5 +1,6 @@
 import unittest.mock
 
+import pydantic
 import pytest
 
 import strands
@@ -39,6 +40,15 @@ def system_prompt():
     return "s1"
 
 
+@pytest.fixture
+def test_output_model_cls():
+    class TestOutputModel(pydantic.BaseModel):
+        name: str
+        age: int
+
+    return TestOutputModel
+
+
 def test__init__(openai_client_cls, model_id):
     model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
 
@@ -173,3 +183,21 @@ def test_stream_with_empty_choices(openai_client, model):
 
     assert tru_events == exp_events
     openai_client.chat.completions.create.assert_called_once_with(**request)
+
+
+def test_structured_output(openai_client, model, test_output_model_cls):
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+
+    mock_parsed_instance = test_output_model_cls(name="John", age=30)
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message.parsed = mock_parsed_instance
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    openai_client.beta.chat.completions.parse.return_value = mock_response
+
+    stream = model.structured_output(test_output_model_cls, messages)
+
+    tru_result = list(stream)[-1]
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
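
Reviewer note: a minimal usage sketch of the generator contract this patch introduces, mirroring the updated Agent.structured_output hunk above. The Person model, the consume_structured_output helper, and the prompt text are illustrative assumptions and are not part of the patch.

# sketch.py -- illustrative only; assumes `model` is one of the Model implementations changed above
from typing import Any

from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


def consume_structured_output(model: Any, messages: list[dict]) -> Person:
    """Drain the structured_output generator and return the final parsed instance.

    Intermediate events may carry a "callback" payload for streaming consumers;
    per the new docstrings, the last yielded event carries the parsed model
    instance under the "output" key.
    """
    events = model.structured_output(Person, messages)
    for event in events:
        if "callback" in event:
            # Forward streaming events to whatever handler the caller prefers.
            print(event["callback"])
    return event["output"]


# Example call shape (hypothetical):
#   person = consume_structured_output(model, [{"role": "user", "content": [{"text": "Jane Doe is 30"}]}])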