diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 677ecb876..956c246ce 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -457,7 +457,7 @@ async def structured_output_async(
         content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
         self._append_message({"role": "user", "content": content})

-        events = self.model.structured_output(output_model, self.messages)
+        events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt)
         async for event in events:
             if "callback" in event:
                 self.callback_handler(**cast(dict, event["callback"]))
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 936f799d7..eb72becfd 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -392,13 +392,14 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
@@ -406,7 +407,7 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)

-        response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
+        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
         async for event in process_stream(response):
             yield event
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index ce76a246a..0dadd9b0e 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -562,13 +562,14 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
@@ -576,7 +577,7 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)

-        response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
+        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
         async for event in streaming.process_stream(response):
             yield event
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 95eb23078..c1e99f1a2 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -184,13 +184,14 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
@@ -199,7 +200,7 @@ async def structured_output(
         response = await litellm.acompletion(
             **self.client_args,
             model=self.get_config()["model_id"],
-            messages=self.format_request(prompt)["messages"],
+            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
             response_format=output_model,
         )
diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 3bae22335..421b06e52 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -407,13 +407,14 @@ async def stream(

     @override
     def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py
index 151b423d1..8855b6d64 100644
--- a/src/strands/models/mistral.py
+++ b/src/strands/models/mistral.py
@@ -493,13 +493,14 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Returns:
@@ -514,7 +515,7 @@ async def structured_output(
             "inputSchema": {"json": output_model.model_json_schema()},
         }

-        formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec])
+        formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt)
         formatted_request["tool_choice"] = "any"
         formatted_request["parallel_tool_calls"] = False
diff --git a/src/strands/models/model.py b/src/strands/models/model.py
index 6de957633..cb24b704d 100644
--- a/src/strands/models/model.py
+++ b/src/strands/models/model.py
@@ -45,13 +45,14 @@ def get_config(self) -> Any:

     @abc.abstractmethod  # pragma: no cover
     def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index 5fb0c1ffe..76cd87d72 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -330,19 +330,20 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
             Model events with the last being the structured output.
         """
-        formatted_request = self.format_request(messages=prompt)
+        formatted_request = self.format_request(messages=prompt, system_prompt=system_prompt)
         formatted_request["format"] = output_model.model_json_schema()
         formatted_request["stream"] = False
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 9a2a87f6a..1076fbae4 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -401,13 +401,14 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
@@ -415,7 +416,7 @@ async def structured_output(
         """
         response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
             model=self.get_config()["model_id"],
-            messages=self.format_request(prompt)["messages"],
+            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
             response_format=output_model,
         )
diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py
index 5ce248a8c..1a87ee8f0 100644
--- a/src/strands/models/writer.py
+++ b/src/strands/models/writer.py
@@ -422,16 +422,17 @@ async def stream(

     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, **kwargs: Any
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
             output_model(Type[BaseModel]): The output model to use for the agent.
             prompt(Messages): The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.
""" - formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=None) + formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=system_prompt) formatted_request["response_format"] = { "type": "json_schema", "json_schema": {"schema": output_model.model_json_schema()}, diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index e4cb5fe93..2a397bb18 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -47,6 +47,8 @@ async def structured_output( self, output_model: Type[T], prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, ) -> AsyncGenerator[Any, None]: pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6de05113b..fd443c833 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -960,7 +960,7 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler -def test_agent_structured_output(agent, user, agenerator): +def test_agent_structured_output(agent, system_prompt, user, agenerator): agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -969,10 +969,12 @@ def test_agent_structured_output(agent, user, agenerator): exp_result = user assert tru_result == exp_result - agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}]) + agent.model.structured_output.assert_called_once_with( + type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt + ) -def test_agent_structured_output_multi_modal_input(agent, user, agenerator): +def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator): agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = [ @@ -991,7 +993,9 @@ def test_agent_structured_output_multi_modal_input(agent, user, agenerator): exp_result = user assert tru_result == exp_result - agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": prompt}]) + agent.model.structured_output.assert_called_once_with( + type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt + ) @pytest.mark.asyncio @@ -1006,7 +1010,7 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator) @pytest.mark.asyncio -async def test_agent_structured_output_async(agent, user, agenerator): +async def test_agent_structured_output_async(agent, system_prompt, user, agenerator): agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -1015,7 +1019,9 @@ async def test_agent_structured_output_async(agent, user, agenerator): exp_result = user assert tru_result == exp_result - agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}]) + agent.model.structured_output.assert_called_once_with( + type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt + ) @pytest.mark.asyncio diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 064d97a2d..175358578 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -16,7 +16,7 @@ 
     def get_config(self):
         return

-    async def structured_output(self, output_model):
+    async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs):
         yield {"output": output_model(name="test", age=20)}

     async def stream(self, messages, tool_specs=None, system_prompt=None):
@@ -95,7 +95,7 @@ async def test_stream(model, messages, tool_specs, system_prompt, alist):

 @pytest.mark.asyncio
-async def test_structured_output(model, alist):
-    response = model.structured_output(Person)
+async def test_structured_output(model, messages, system_prompt, alist):
+    response = model.structured_output(Person, prompt=messages, system_prompt=system_prompt)
     events = await alist(response)

     tru_output = events[-1]["output"]
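
Taken together, these changes thread the agent's system prompt through every provider's `structured_output` path instead of silently dropping it. A minimal usage sketch of the resulting behavior follows; the `PersonInfo` model and the prompt strings are illustrative examples, not code from this diff:

```python
from pydantic import BaseModel

from strands import Agent


class PersonInfo(BaseModel):
    """Illustrative output model; not part of this change."""

    name: str
    age: int
    email: str


# With this change, the system prompt given to the Agent is forwarded to
# Model.structured_output(..., system_prompt=...) rather than being omitted.
agent = Agent(system_prompt="Extract structured data about people.")
result = agent.structured_output(PersonInfo, "Jane Doe is 30 years old and her email is jane@doe.com")
print(result.name, result.age, result.email)
```

Because the new parameter is optional and defaults to `None`, existing `Model` subclasses that ignore it keep working, and providers without a native system role can simply not use it.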