diff --git a/src/strands/models/_config_validation.py b/src/strands/models/_config_validation.py new file mode 100644 index 000000000..085449bb8 --- /dev/null +++ b/src/strands/models/_config_validation.py @@ -0,0 +1,27 @@ +"""Configuration validation utilities for model providers.""" + +import warnings +from typing import Any, Mapping, Type + +from typing_extensions import get_type_hints + + +def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: + """Validate that config keys match the TypedDict fields. + + Args: + config_dict: Dictionary of configuration parameters + config_class: TypedDict class to validate against + """ + valid_keys = set(get_type_hints(config_class).keys()) + provided_keys = set(config_dict.keys()) + invalid_keys = provided_keys - valid_keys + + if invalid_keys: + warnings.warn( + f"Invalid configuration parameters: {sorted(invalid_keys)}." + f"\nValid parameters are: {sorted(valid_keys)}." + f"\n" + f"\nSee https://github.com/strands-agents/sdk-python/issues/815", + stacklevel=4, + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 29cb40d40..06dc816f2 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -19,6 +19,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -67,6 +68,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks. **model_config: Configuration options for the Anthropic model. """ + validate_config_keys(model_config, self.AnthropicConfig) self.config = AnthropicModel.AnthropicConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # typ Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.AnthropicConfig) self.config.update(model_config) @override diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index aa19b114d..f18422191 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -24,6 +24,7 @@ ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolResult, ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -166,6 +167,7 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.BedrockConfig) self.config.update(model_config) @override diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c1e99f1a2..9a31e82df 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -15,6 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -49,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: **model_config: Configuration options for the LiteLLM model. """ self.client_args = client_args or {} + validate_config_keys(model_config, self.LiteLLMConfig) self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -60,6 +62,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.LiteLLMConfig) self.config.update(model_config) @override diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 421b06e52..57ff85c66 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -19,6 +19,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ def __init__( client_args: Arguments for the Llama API client. **model_config: Configuration options for the Llama API model. """ + validate_config_keys(model_config, self.LlamaConfig) self.config = LlamaAPIModel.LlamaConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: i Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.LlamaConfig) self.config.update(model_config) @override diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 8855b6d64..401dde98e 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -16,6 +16,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -82,6 +83,7 @@ def __init__( if not 0.0 <= top_p <= 1.0: raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}") + validate_config_keys(model_config, self.MistralConfig) self.config = MistralModel.MistralConfig(**model_config) # Set default stream to True if not specified @@ -101,6 +103,7 @@ def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.MistralConfig) self.config.update(model_config) @override diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 76cd87d72..4025dc062 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -14,6 +14,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ def __init__( """ self.host = host self.client_args = ollama_client_args or {} + validate_config_keys(model_config, self.OllamaConfig) self.config = OllamaModel.OllamaConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[OllamaConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.OllamaConfig) self.config.update(model_config) @override diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1076fbae4..16eb4defe 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -17,6 +17,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: For a complete list of supported arguments, see https://pypi.org/project/openai/. **model_config: Configuration options for the OpenAI model. """ + validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.OpenAIConfig) self.config.update(model_config) @override diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 9cfe27d9e..74069b895 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -15,6 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec +from ._config_validation import validate_config_keys from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -146,6 +147,8 @@ def __init__( boto_session: Boto Session to use when calling the SageMaker Runtime. boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client. """ + validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig) + validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) payload_config.setdefault("stream", True) payload_config.setdefault("tool_results_as_user_messages", False) self.endpoint_config = dict(endpoint_config) @@ -180,6 +183,7 @@ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> Args: **endpoint_config: Configuration overrides. """ + validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig) self.endpoint_config.update(endpoint_config) @override diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f6a3da3d8..9bcdaad42 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -17,6 +17,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.). **model_config: Configuration options for the Writer model. """ + validate_config_keys(model_config, self.WriterConfig) self.config = WriterModel.WriterConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -67,6 +69,7 @@ def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.WriterConfig) self.config.update(model_config) @override diff --git a/tests/conftest.py b/tests/conftest.py index 3b82e362c..f2a8909cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import logging import os import sys +import warnings import boto3 import moto @@ -107,3 +108,13 @@ def generate(generator): return events, stop.value return generate + + +## Warnings + + +@pytest.fixture +def captured_warnings(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + yield w diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 5e8d69ea7..9a7a4be11 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -767,3 +767,21 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(anthropic_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + AnthropicModel(model_id="test-model", max_tokens=100, invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f2e459bde..624eec6e9 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1445,3 +1445,21 @@ async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist): assert len(sent_messages) == 2 assert sent_messages[0]["content"] == [{"text": "Hello"}] assert sent_messages[1]["content"] == [{"text": "Follow up"}] + + +def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + BedrockModel(model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 44b6df63b..9140cadcc 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -252,3 +252,21 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings): + """Test that unknown config keys emit a warning.""" + LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 309dac2e9..712ef8b7a 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -361,3 +361,21 @@ def test_format_chunk_other(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +def test_config_validation_warns_on_unknown_keys(llamaapi_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + LlamaAPIModel(model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 2a78024f2..9b3f62a31 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -539,3 +539,21 @@ async def test_structured_output_invalid_json(mistral_client, model, test_output with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"): stream = model.structured_output(test_output_model_cls, prompt) await anext(stream) + + +def test_config_validation_warns_on_unknown_keys(mistral_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + MistralModel(model_id="test-model", max_tokens=100, invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index c3fb7736e..9a63a3214 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -516,3 +516,21 @@ async def test_structured_output(ollama_client, model, test_output_model_cls, al tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(ollama_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OllamaModel("http://localhost:11434", model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index a7c97701c..00cae7447 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -583,3 +583,21 @@ async def test_structured_output(openai_client, model, test_output_model_cls, al tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OpenAIModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index ba395b2d6..a9071c7e2 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -572,3 +572,44 @@ def test_tool_call(self): assert tool2.type == "function" assert tool2.function.name == "get_time" assert tool2.function.arguments == '{"timezone": "UTC"}' + + +def test_config_validation_warns_on_unknown_keys_in_endpoint(boto_session, captured_warnings): + """Test that unknown config keys emit a warning.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1", "invalid_param": "test"} + payload_config = {"max_tokens": 1024} + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + ) + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_config_validation_warns_on_unknown_keys_in_payload(boto_session, captured_warnings): + """Test that unknown config keys emit a warning.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024, "invalid_param": "test"} + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + ) + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index f7748cfdb..75896ca68 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -380,3 +380,21 @@ async def test_stream_with_empty_choices(writer_client, model, model_id): "stream_options": {"include_usage": True}, } writer_client.chat.chat.assert_called_once_with(**expected_request) + + +def test_config_validation_warns_on_unknown_keys(writer_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + WriterModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message)