@@ -5,6 +5,8 @@

from typing_extensions import get_type_hints

from ..types.tools import ToolChoice


def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None:
"""Validate that config keys match the TypedDict fields.
@@ -25,3 +27,16 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) ->
f"\nSee https://github.com/strands-agents/sdk-python/issues/815",
stacklevel=4,
)


def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None:
"""Emits a warning if a tool choice is provided but not supported by the provider.

Args:
tool_choice: the tool_choice provided to the provider
"""
if tool_choice:
warnings.warn(
"A ToolChoice was provided to this provider but is not supported and will be ignored",
stacklevel=4,
)
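Providers that cannot honor a tool choice call this helper at the top of `stream()` (see the LlamaAPI, Mistral, and Ollama changes below), so the setting surfaces as a warning instead of being silently dropped. A minimal sketch of the expected behavior; the import path is assumed from the `from ._validation import ...` lines elsewhere in this diff:

```python
import warnings

# Assumed module path, inferred from the imports in this diff.
from strands.models._validation import warn_on_tool_choice_not_supported

# No-op when the caller did not request a tool selection strategy.
warn_on_tool_choice_not_supported(None)

# Emits a UserWarning so the caller knows the setting will be ignored.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_tool_choice_not_supported({"any": {}})
    assert "will be ignored" in str(caught[0].message)
```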
38 changes: 33 additions & 5 deletions src/strands/models/anthropic.py
@@ -18,8 +18,8 @@
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
from ._validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
@@ -195,14 +195,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
return formatted_messages

def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
self,
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Format an Anthropic streaming request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.

Returns:
An Anthropic streaming request.
@@ -223,10 +228,25 @@
}
for tool_spec in tool_specs or []
],
**(self._format_tool_choice(tool_choice)),
**({"system": system_prompt} if system_prompt else {}),
**(self.config.get("params") or {}),
}

@staticmethod
def _format_tool_choice(tool_choice: ToolChoice | None) -> dict:
if tool_choice is None:
return {}

if "any" in tool_choice:
return {"tool_choice": {"type": "any"}}
elif "auto" in tool_choice:
return {"tool_choice": {"type": "auto"}}
elif "tool" in tool_choice:
return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}}
else:
return {}

def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Anthropic response events into standardized message chunks.

@@ -350,6 +370,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Anthropic model.
@@ -358,6 +379,7 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
@@ -368,7 +390,7 @@
ModelThrottledException: If the request is throttled by Anthropic.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
@@ -410,7 +432,13 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=cast(ToolChoice, {"any": {}}),
**kwargs,
)
async for event in process_stream(response):
yield event

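The new `_format_tool_choice` helper above translates the provider-agnostic `ToolChoice` dicts into Anthropic's `tool_choice` request field, and `structured_output` now forces `{"any": {}}` so the model must call the conversion tool. A rough sketch of the mapping, assuming the class shown here is `AnthropicModel` and the `anthropic` extra is installed:

```python
from strands.models.anthropic import AnthropicModel

# Omitted entirely when no preference is given.
assert AnthropicModel._format_tool_choice(None) == {}

# "auto" and "any" map directly onto Anthropic's own type names.
assert AnthropicModel._format_tool_choice({"auto": {}}) == {"tool_choice": {"type": "auto"}}
assert AnthropicModel._format_tool_choice({"any": {}}) == {"tool_choice": {"type": "any"}}

# Forcing a specific tool carries the tool name through.
assert AnthropicModel._format_tool_choice({"tool": {"name": "get_weather"}}) == {
    "tool_choice": {"type": "tool", "name": "get_weather"}
}
```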
17 changes: 12 additions & 5 deletions src/strands/models/bedrock.py
@@ -23,8 +23,8 @@
ModelThrottledException,
)
from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolResult, ToolSpec
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolResult, ToolSpec
from ._validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
@@ -195,13 +195,15 @@ def format_request(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Format a Bedrock converse stream request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.

Returns:
A Bedrock converse stream request.
@@ -224,7 +226,7 @@
else []
),
],
"toolChoice": {"auto": {}},
**({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
}
}
if tool_specs
@@ -416,6 +418,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Bedrock model.
@@ -427,6 +430,7 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
@@ -445,7 +449,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
loop = asyncio.get_event_loop()
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()

thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice)
task = asyncio.create_task(thread)

while True:
@@ -463,6 +467,7 @@ def _stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
) -> None:
"""Stream conversation with the Bedrock model.

@@ -474,14 +479,15 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""
try:
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
@@ -738,6 +744,7 @@ async def structured_output(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=cast(ToolChoice, {"any": {}}),
**kwargs,
)
async for event in streaming.process_stream(response):
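Bedrock needs no translation step: the SDK's `ToolChoice` dicts already mirror the Converse API's `toolChoice` shape, so the value is forwarded unchanged and the previous hard-coded `{"auto": {}}` remains the default. A rough sketch under assumptions not in this diff — that the class is `BedrockModel`, that the surrounding request key is `toolConfig`, and that boto3 can resolve a region locally; the model id is only an example:

```python
from strands.models.bedrock import BedrockModel

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")  # example model id
messages = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]
tool_specs = [
    {
        "name": "add",
        "description": "Add two numbers.",
        "inputSchema": {"json": {"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}},
    }
]

# Without a tool_choice, the request keeps the previous default.
request = model.format_request(messages, tool_specs)
assert request["toolConfig"]["toolChoice"] == {"auto": {}}

# A caller-supplied ToolChoice is passed straight through to Converse.
request = model.format_request(messages, tool_specs, tool_choice={"tool": {"name": "add"}})
assert request["toolConfig"]["toolChoice"] == {"tool": {"name": "add"}}
```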
8 changes: 5 additions & 3 deletions src/strands/models/litellm.py
@@ -14,8 +14,8 @@

from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
from .openai import OpenAIModel

logger = logging.getLogger(__name__)
@@ -114,6 +114,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LiteLLM model.
@@ -122,13 +123,14 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Formatted message chunks from the model.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
9 changes: 7 additions & 2 deletions src/strands/models/llamaapi.py
@@ -18,8 +18,8 @@
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.streaming import StreamEvent, Usage
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
from .model import Model

logger = logging.getLogger(__name__)
@@ -330,6 +330,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LlamaAPI model.
@@ -338,6 +339,8 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
interface consistency but is currently ignored for this model provider.**
**kwargs: Additional keyword arguments for future extensibility.

Yields:
@@ -346,6 +349,8 @@
Raises:
ModelThrottledException: When the model service is throttling requests from the client.
"""
warn_on_tool_choice_not_supported(tool_choice)

logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("request=<%s>", request)
9 changes: 7 additions & 2 deletions src/strands/models/mistral.py
@@ -15,8 +15,8 @@
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
from .model import Model

logger = logging.getLogger(__name__)
@@ -397,6 +397,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Mistral model.
@@ -405,6 +406,8 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
interface consistency but is currently ignored for this model provider.**
**kwargs: Additional keyword arguments for future extensibility.

Yields:
@@ -413,6 +416,8 @@
Raises:
ModelThrottledException: When the model service is throttling requests.
"""
warn_on_tool_choice_not_supported(tool_choice)

logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("request=<%s>", request)
4 changes: 3 additions & 1 deletion src/strands/models/model.py
@@ -8,7 +8,7 @@

from ..types.content import Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ..types.tools import ToolChoice, ToolSpec

logger = logging.getLogger(__name__)

@@ -70,6 +70,7 @@ def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncIterable[StreamEvent]:
"""Stream conversation with the model.
@@ -84,6 +85,7 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
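With `tool_choice` on the abstract `stream` signature, a caller states the selection strategy once and each provider either translates it or warns and ignores it, as above. A hypothetical call site; the model id and tool spec are illustrative, and an `ANTHROPIC_API_KEY` is assumed to be set:

```python
import asyncio

from strands.models.anthropic import AnthropicModel  # any Model implementation accepts the same argument

async def main() -> None:
    model = AnthropicModel(model_id="claude-sonnet-4-20250514", max_tokens=1024)
    tool_specs = [
        {
            "name": "get_time",
            "description": "Return the current time.",
            "inputSchema": {"json": {"type": "object", "properties": {}}},
        }
    ]
    messages = [{"role": "user", "content": [{"text": "What time is it?"}]}]

    # Ask the provider to call some tool rather than answer in prose.
    async for event in model.stream(messages, tool_specs, tool_choice={"any": {}}):
        print(event)

asyncio.run(main())
```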
9 changes: 7 additions & 2 deletions src/strands/models/ollama.py
@@ -13,8 +13,8 @@

from ..types.content import ContentBlock, Messages
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
from .model import Model

logger = logging.getLogger(__name__)
@@ -287,6 +287,7 @@ async def stream(
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Ollama model.
@@ -295,11 +296,15 @@
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
interface consistency but is currently ignored for this model provider.**
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Formatted message chunks from the model.
"""
warn_on_tool_choice_not_supported(tool_choice)

logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("request=<%s>", request)