Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ async def structured_output_async(
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
self._append_message({"role": "user", "content": content})

events = self.model.structured_output(output_model, self.messages)
events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))
Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,21 +392,22 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Model events with the last being the structured output.
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
async for event in process_stream(response):
yield event

Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,21 +562,22 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Model events with the last being the structured output.
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
async for event in streaming.process_stream(response):
yield event

Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,14 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand All @@ -199,7 +200,7 @@ async def structured_output(
response = await litellm.acompletion(
**self.client_args,
model=self.get_config()["model_id"],
messages=self.format_request(prompt)["messages"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)

Expand Down
3 changes: 2 additions & 1 deletion src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,14 @@ async def stream(

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,14 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
Expand All @@ -514,7 +515,7 @@ async def structured_output(
"inputSchema": {"json": output_model.model_json_schema()},
}

formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec])
formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt)

formatted_request["tool_choice"] = "any"
formatted_request["parallel_tool_calls"] = False
Expand Down
3 changes: 2 additions & 1 deletion src/strands/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@ def get_config(self) -> Any:
@abc.abstractmethod
# pragma: no cover
def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,19 +330,20 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Model events with the last being the structured output.
"""
formatted_request = self.format_request(messages=prompt)
formatted_request = self.format_request(messages=prompt, system_prompt=system_prompt)
formatted_request["format"] = output_model.model_json_schema()
formatted_request["stream"] = False

Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,22 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Model events with the last being the structured output.
"""
response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore
model=self.get_config()["model_id"],
messages=self.format_request(prompt)["messages"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)

Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,16 +422,17 @@ async def stream(

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, **kwargs: Any
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Args:
output_model(Type[BaseModel]): The output model to use for the agent.
prompt(Messages): The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.
"""
formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=None)
formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=system_prompt)
formatted_request["response_format"] = {
"type": "json_schema",
"json_schema": {"schema": output_model.model_json_schema()},
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/mocked_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ async def structured_output(
self,
output_model: Type[T],
prompt: Messages,
system_prompt: Optional[str] = None,
**kwargs: Any,
) -> AsyncGenerator[Any, None]:
pass

Expand Down
18 changes: 12 additions & 6 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ def test_agent_callback_handler_custom_handler_used():
assert agent.callback_handler is custom_handler


def test_agent_structured_output(agent, user, agenerator):
def test_agent_structured_output(agent, system_prompt, user, agenerator):
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

prompt = "Jane Doe is 30 years old and her email is [email protected]"
Expand All @@ -969,10 +969,12 @@ def test_agent_structured_output(agent, user, agenerator):
exp_result = user
assert tru_result == exp_result

agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}])
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)


def test_agent_structured_output_multi_modal_input(agent, user, agenerator):
def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

prompt = [
Expand All @@ -991,7 +993,9 @@ def test_agent_structured_output_multi_modal_input(agent, user, agenerator):
exp_result = user
assert tru_result == exp_result

agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": prompt}])
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
)


@pytest.mark.asyncio
Expand All @@ -1006,7 +1010,7 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator)


@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, user, agenerator):
async def test_agent_structured_output_async(agent, system_prompt, user, agenerator):
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

prompt = "Jane Doe is 30 years old and her email is [email protected]"
Expand All @@ -1015,7 +1019,9 @@ async def test_agent_structured_output_async(agent, user, agenerator):
exp_result = user
assert tru_result == exp_result

agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}])
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)


@pytest.mark.asyncio
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def update_config(self, **model_config):
def get_config(self):
return

async def structured_output(self, output_model):
async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs):
yield {"output": output_model(name="test", age=20)}

async def stream(self, messages, tool_specs=None, system_prompt=None):
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_stream(model, messages, tool_specs, system_prompt, alist):

@pytest.mark.asyncio
async def test_structured_output(model, alist):
response = model.structured_output(Person)
response = model.structured_output(Person, prompt=messages, system_prompt=system_prompt)
events = await alist(response)

tru_output = events[-1]["output"]
Expand Down
Loading