3 changes: 2 additions & 1 deletion src/strands/event_loop/streaming.py
@@ -321,6 +321,7 @@ async def stream_messages(

messages = remove_blank_messages_content_text(messages)

chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt)
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)

async for event in process_stream(chunks, messages):
yield event
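
The event loop now hands the conversation inputs straight to the provider, which builds its own request. A minimal consumer sketch of the new call shape (the `model` instance and message payload below are illustrative assumptions, not part of this diff):

```python
async def consume(model, messages, tool_specs=None, system_prompt=None):
    # model.stream accepts conversation inputs directly; the provider formats
    # its own request internally and yields standardized StreamEvent chunks.
    async for chunk in model.stream(messages, tool_specs, system_prompt):
        print(chunk)


# Hypothetical invocation with a Bedrock-style message payload:
# asyncio.run(consume(model, [{"role": "user", "content": [{"text": "Hello"}]}]))
```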
30 changes: 20 additions & 10 deletions src/strands/models/anthropic.py
@@ -191,7 +191,6 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:

return formatted_messages

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
@@ -225,7 +224,6 @@ def format_request(
**(self.config.get("params") or {}),
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Anthropic response events into standardized message chunks.

@@ -344,27 +342,37 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
raise RuntimeError(f"event_type=<{event['type']} | unknown type")

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the Anthropic model and get the streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Anthropic model.

Args:
request: The formatted request to send to the Anthropic model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the Anthropic model.
Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by Anthropic.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
try:
async with self.client.messages.stream(**request) as stream:
logger.debug("got response from model")
async for event in stream:
if event.type in AnthropicModel.EVENT_TYPES:
yield event.model_dump()
yield self.format_chunk(event.model_dump())

usage = event.message.usage # type: ignore
yield {"type": "metadata", "usage": usage.model_dump()}
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})

except anthropic.RateLimitError as error:
raise ModelThrottledException(str(error)) from error
@@ -375,6 +383,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]

raise error

logger.debug("finished streaming response from model")

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages
@@ -390,7 +400,7 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
response = self.stream(messages=prompt, tool_specs=[tool_spec])
async for event in process_stream(response, prompt):
yield event

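With formatting folded into `stream`, callers of the Anthropic provider receive standardized chunks instead of raw SDK events. A rough usage sketch, assuming constructor keyword arguments such as `model_id` and `max_tokens` (the values are placeholders):

```python
from strands.models.anthropic import AnthropicModel

# Placeholder configuration; real values depend on your account and model choice.
model = AnthropicModel(model_id="claude-3-7-sonnet-latest", max_tokens=1024)


async def chat():
    async for chunk in model.stream(
        messages=[{"role": "user", "content": [{"text": "Hello"}]}],
        system_prompt="You are concise.",
    ):
        # Each chunk has already been passed through format_chunk, so it is a
        # standardized StreamEvent rather than a raw anthropic SDK event.
        print(chunk)
```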
28 changes: 19 additions & 9 deletions src/strands/models/bedrock.py
@@ -162,7 +162,6 @@ def get_config(self) -> BedrockConfig:
"""
return self.config

@override
def format_request(
self,
messages: Messages,
@@ -246,7 +245,6 @@ def format_request(
),
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Bedrock response events into standardized message chunks.

@@ -315,25 +313,35 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
return events

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
"""Send the request to the Bedrock model and get the response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Bedrock model.

This method calls either the Bedrock converse_stream API or the converse API
based on the streaming parameter in the configuration.

Args:
request: The formatted request to send to the Bedrock model
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the Bedrock model
Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
streaming = self.config.get("streaming", True)

try:
logger.debug("got response from model")
if streaming:
# Streaming implementation
response = self.client.converse_stream(**request)
@@ -347,7 +355,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, N
if self._has_blocked_guardrail(guardrail_data):
for event in self._generate_redaction_events():
yield event
yield chunk
yield self.format_chunk(chunk)
else:
# Non-streaming implementation
response = self.client.converse(**request)
@@ -406,6 +414,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, N
# Otherwise raise the error
raise e

logger.debug("finished streaming response from model")

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

@@ -531,7 +541,7 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
response = self.stream(messages=prompt, tool_specs=[tool_spec])
async for event in process_stream(response, prompt):
yield event

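Because `stream` now formats the request and converts non-streaming responses itself, callers see the same chunk shape regardless of the `streaming` flag. A sketch under that assumption (the model id is a placeholder):

```python
from strands.models.bedrock import BedrockModel

# Placeholder model id; the streaming flag selects converse_stream vs. converse.
model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", streaming=False)


async def run(messages):
    async for chunk in model.stream(messages, tool_specs=None, system_prompt=None):
        # Non-streaming converse responses are converted internally, so chunks
        # keep the standardized StreamEvent format in both modes.
        print(chunk)
```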
58 changes: 38 additions & 20 deletions src/strands/models/litellm.py
@@ -14,6 +14,8 @@

from ..types.content import ContentBlock, Messages
from ..types.models.openai import OpenAIModel
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)

@@ -104,19 +106,29 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
return super().format_request_message_content(content)

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the LiteLLM model and get the streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LiteLLM model.

Args:
request: The formatted request to send to the LiteLLM model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the LiteLLM model.
Yields:
Formatted message chunks from the model.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
response = self.client.chat.completions.create(**request)

yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}
logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

tool_calls: dict[int, list[Any]] = {}

@@ -127,38 +139,44 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
choice = event.choices[0]

if choice.delta.content:
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield {
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
break

yield {"chunk_type": "content_stop", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

yield {"chunk_type": "content_stop", "data_type": "tool"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield {"chunk_type": "message_stop", "data": choice.finish_reason}
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

logger.debug("finished streaming response from model")

@override
async def structured_output(
Expand All @@ -178,7 +196,7 @@ async def structured_output(
# completions() has a method `create()` which wraps the real completion API of Litellm
response = self.client.chat.completions.create(
model=self.get_config()["model_id"],
messages=super().format_request(prompt)["messages"],
messages=self.format_request(prompt)["messages"],
response_format=output_model,
)

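With `structured_output` now building the request through the provider's own `format_request`, the calling pattern stays the same. A rough sketch; the Pydantic model, the model id, and the way the final result is surfaced are assumptions rather than facts from this diff:

```python
from pydantic import BaseModel

from strands.models.litellm import LiteLLMModel


class Answer(BaseModel):
    text: str


# Placeholder configuration for illustration only.
model = LiteLLMModel(model_id="openai/gpt-4o")


async def extract(prompt):
    last_event = None
    async for event in model.structured_output(Answer, prompt):
        last_event = event  # the parsed Answer is expected in the final event
    return last_event
```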
46 changes: 29 additions & 17 deletions src/strands/models/llamaapi.py
@@ -202,7 +202,6 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
@@ -249,7 +248,6 @@ def format_request(

return request

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Llama API model response events into standardized message chunks.

@@ -324,24 +322,34 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the model and get a streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LlamaAPI model.

Args:
request: The formatted request to send to the model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
The model's response.
Yields:
Formatted message chunks from the model.

Raises:
ModelThrottledException: When the model service is throttling requests from the client.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
try:
response = self.client.chat.completions.create(**request)
except llama_api_client.RateLimitError as e:
raise ModelThrottledException(str(e)) from e

yield {"chunk_type": "message_start"}
logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})

stop_reason = None
tool_calls: dict[Any, list[Any]] = {}
@@ -350,9 +358,11 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
metrics_event = None
for chunk in response:
if chunk.event.event_type == "start":
yield {"chunk_type": "content_start", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
yield {"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
)
else:
if chunk.event.delta.type == "tool_call":
if chunk.event.delta.id:
@@ -364,29 +374,31 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
elif chunk.event.event_type == "metrics":
metrics_event = chunk.event.metrics
else:
yield chunk
yield self.format_chunk(chunk)

if stop_reason is None:
stop_reason = chunk.event.stop_reason

# stopped generation
if stop_reason:
yield {"chunk_type": "content_stop", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_start})

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

yield {"chunk_type": "content_stop", "data_type": "tool"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield {"chunk_type": "message_stop", "data": stop_reason}
yield self.format_chunk({"chunk_type": "message_stop", "data": stop_reason})

# we may have a metrics event here
if metrics_event:
yield {"chunk_type": "metadata", "data": metrics_event}
yield self.format_chunk({"chunk_type": "metadata", "data": metrics_event})

logger.debug("finished streaming response from model")

@override
def structured_output(