Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/strands/models/_config_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Configuration validation utilities for model providers."""

import warnings
from typing import Any, Mapping, Type

from typing_extensions import get_type_hints


def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None:
    """Warn when *config_dict* contains keys not declared on *config_class*.

    Args:
        config_dict: Dictionary of configuration parameters
        config_class: TypedDict class to validate against
    """
    declared = set(get_type_hints(config_class))
    unknown = set(config_dict) - declared
    if not unknown:
        return

    # stacklevel=4 points the warning at the caller of the model constructor /
    # update_config rather than at this helper.
    warnings.warn(
        f"Invalid configuration parameters: {sorted(unknown)}."
        f"\nValid parameters are: {sorted(declared)}."
        f"\n"
        f"\nSee https://github.com/strands-agents/sdk-python/issues/815",
        stacklevel=4,
    )
3 changes: 3 additions & 0 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf
For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks.
**model_config: Configuration options for the Anthropic model.
"""
validate_config_keys(model_config, self.AnthropicConfig)
self.config = AnthropicModel.AnthropicConfig(**model_config)

logger.debug("config=<%s> | initializing", self.config)
Expand All @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # typ
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.AnthropicConfig)
self.config.update(model_config)

@override
Expand Down
2 changes: 2 additions & 0 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolResult, ToolSpec
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -166,6 +167,7 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.BedrockConfig)
self.config.update(model_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from .openai import OpenAIModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
**model_config: Configuration options for the LiteLLM model.
"""
self.client_args = client_args or {}
validate_config_keys(model_config, self.LiteLLMConfig)
self.config = dict(model_config)

logger.debug("config=<%s> | initializing", self.config)
Expand All @@ -60,6 +62,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.LiteLLMConfig)
self.config.update(model_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..types.exceptions import ModelThrottledException
from ..types.streaming import StreamEvent, Usage
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
client_args: Arguments for the Llama API client.
**model_config: Configuration options for the Llama API model.
"""
validate_config_keys(model_config, self.LlamaConfig)
self.config = LlamaAPIModel.LlamaConfig(**model_config)
logger.debug("config=<%s> | initializing", self.config)

Expand All @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: i
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.LlamaConfig)
self.config.update(model_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..types.exceptions import ModelThrottledException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
if not 0.0 <= top_p <= 1.0:
raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}")

validate_config_keys(model_config, self.MistralConfig)
self.config = MistralModel.MistralConfig(**model_config)

# Set default stream to True if not specified
Expand All @@ -101,6 +103,7 @@ def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.MistralConfig)
self.config.update(model_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..types.content import ContentBlock, Messages
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolSpec
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
"""
self.host = host
self.client_args = ollama_client_args or {}
validate_config_keys(model_config, self.OllamaConfig)
self.config = OllamaModel.OllamaConfig(**model_config)

logger.debug("config=<%s> | initializing", self.config)
Expand All @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[OllamaConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.OllamaConfig)
self.config.update(model_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
For a complete list of supported arguments, see https://pypi.org/project/openai/.
**model_config: Configuration options for the OpenAI model.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config = dict(model_config)

logger.debug("config=<%s> | initializing", self.config)
Expand All @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config.update(model_config)

@override
Expand Down
4 changes: 4 additions & 0 deletions src/strands/models/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec
from ._config_validation import validate_config_keys
from .openai import OpenAIModel

T = TypeVar("T", bound=BaseModel)
Expand Down Expand Up @@ -146,6 +147,8 @@ def __init__(
boto_session: Boto Session to use when calling the SageMaker Runtime.
boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
"""
validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig)
validate_config_keys(payload_config, self.SageMakerAIPayloadSchema)
payload_config.setdefault("stream", True)
payload_config.setdefault("tool_results_as_user_messages", False)
self.endpoint_config = dict(endpoint_config)
Expand Down Expand Up @@ -180,6 +183,7 @@ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) ->
Args:
**endpoint_config: Configuration overrides.
"""
validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig)
self.endpoint_config.update(endpoint_config)

@override
Expand Down
3 changes: 3 additions & 0 deletions src/strands/models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..types.exceptions import ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ._config_validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.).
**model_config: Configuration options for the Writer model.
"""
validate_config_keys(model_config, self.WriterConfig)
self.config = WriterModel.WriterConfig(**model_config)

logger.debug("config=<%s> | initializing", self.config)
Expand All @@ -67,6 +69,7 @@ def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.WriterConfig)
self.config.update(model_config)

@override
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import sys
import warnings

import boto3
import moto
Expand Down Expand Up @@ -107,3 +108,13 @@ def generate(generator):
return events, stop.value

return generate


## Warnings


@pytest.fixture
def captured_warnings():
    """Record every warning raised during the test, regardless of active filters."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield caught
18 changes: 18 additions & 0 deletions tests/strands/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,3 +767,21 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
tru_result = events[-1]
exp_result = {"output": test_output_model_cls(name="John", age=30)}
assert tru_result == exp_result


def test_config_validation_warns_on_unknown_keys(anthropic_client, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    AnthropicModel(model_id="test-model", max_tokens=100, invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
18 changes: 18 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,3 +1445,21 @@ async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist):
assert len(sent_messages) == 2
assert sent_messages[0]["content"] == [{"text": "Hello"}]
assert sent_messages[1]["content"] == [{"text": "Follow up"}]


def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    BedrockModel(model_id="test-model", invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
18 changes: 18 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,21 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c

exp_result = {"output": test_output_model_cls(name="John", age=30)}
assert tru_result == exp_result


def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
18 changes: 18 additions & 0 deletions tests/strands/models/test_llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,21 @@ def test_format_chunk_other(model):

with pytest.raises(RuntimeError, match="chunk_type=<other> | unknown type"):
model.format_chunk(event)


def test_config_validation_warns_on_unknown_keys(llamaapi_client, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    LlamaAPIModel(model_id="test-model", invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
18 changes: 18 additions & 0 deletions tests/strands/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,21 @@ async def test_structured_output_invalid_json(mistral_client, model, test_output
with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"):
stream = model.structured_output(test_output_model_cls, prompt)
await anext(stream)


def test_config_validation_warns_on_unknown_keys(mistral_client, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    MistralModel(model_id="test-model", max_tokens=100, invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
18 changes: 18 additions & 0 deletions tests/strands/models/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,21 @@ async def test_structured_output(ollama_client, model, test_output_model_cls, al
tru_result = events[-1]
exp_result = {"output": test_output_model_cls(name="John", age=30)}
assert tru_result == exp_result


def test_config_validation_warns_on_unknown_keys(ollama_client, captured_warnings):
    """Constructing the model with an unrecognized config key should warn once."""
    OllamaModel("http://localhost:11434", model_id="test-model", invalid_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "invalid_param" in message


def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings):
    """Passing an unrecognized key to update_config should warn once."""
    model.update_config(wrong_param="test")

    assert len(captured_warnings) == 1
    message = str(captured_warnings[0].message)
    assert "Invalid configuration parameters" in message
    assert "wrong_param" in message
Loading
Loading