diff --git a/tests_integ/models/conformance.py b/tests_integ/models/conformance.py
new file mode 100644
index 000000000..262e41e42
--- /dev/null
+++ b/tests_integ/models/conformance.py
@@ -0,0 +1,29 @@
+import pytest
+
+from strands.types.models import Model
+from tests_integ.models.providers import ProviderInfo, all_providers
+
+
+def get_models():
+    return [
+        pytest.param(
+            provider_info,
+            id=provider_info.id,  # Adds the provider name to the test name
+            marks=[provider_info.mark],  # Skips tests whose requirements are not met
+        )
+        for provider_info in all_providers
+    ]
+
+
+@pytest.fixture(params=get_models())
+def provider_info(request) -> ProviderInfo:
+    return request.param
+
+
+@pytest.fixture()
+def model(provider_info):
+    return provider_info.create_model()
+
+
+def test_model_can_be_constructed(model: Model):
+    assert model is not None
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py
index a789f7b41..f15628eaf 100644
--- a/tests_integ/models/providers.py
+++ b/tests_integ/models/providers.py
@@ -1,28 +1,51 @@
+"""
+Aggregates all providers so that they can all be tested in one go.
+"""
+
 import os
-from dataclasses import dataclass
+from typing import Callable, Optional
 
 import requests
 from pytest import mark
 
+from strands.models import BedrockModel
+from strands.models.anthropic import AnthropicModel
+from strands.models.litellm import LiteLLMModel
+from strands.models.llamaapi import LlamaAPIModel
+from strands.models.mistral import MistralModel
+from strands.models.ollama import OllamaModel
+from strands.models.openai import OpenAIModel
+from strands.models.writer import WriterModel
+from strands.types.models import Model
+
 
-@dataclass
-class ApiKeyProviderInfo:
-    """Provider-based info for providers that require an APIKey via environment variables."""
+class ProviderInfo:
+    """Info for a model provider: a factory for constructing the model plus an optional required environment variable."""
 
-    def __init__(self, id: str, environment_variable: str) -> None:
+    def __init__(
+        self,
+        id: str,
+        factory: Callable[[], Model],
+        environment_variable: Optional[str] = None,
+    ) -> None:
         self.id = id
-        self.environment_variable = environment_variable
+        self.model_factory = factory
         self.mark = mark.skipif(
-            self.environment_variable not in os.environ,
-            reason=f"{self.environment_variable} environment variable missing",
+            environment_variable is not None and environment_variable not in os.environ,
+            reason=f"{environment_variable} environment variable missing",
         )
 
+    def create_model(self) -> Model:
+        return self.model_factory()
+
 
-class OllamaProviderInfo:
-    """Special case ollama as it's dependent on the server being available."""
+class OllamaProviderInfo(ProviderInfo):
+    """Special-case Ollama, since it depends on the local server being available."""
 
     def __init__(self):
-        self.id = "ollama"
+        super().__init__(
+            id="ollama", factory=lambda: OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")
+        )
 
         is_server_available = False
         try:
@@ -36,11 +59,85 @@ def __init__(self):
         )
 
 
-anthropic = ApiKeyProviderInfo(id="anthropic", environment_variable="ANTHROPIC_API_KEY")
-cohere = ApiKeyProviderInfo(id="cohere", environment_variable="CO_API_KEY")
-llama = ApiKeyProviderInfo(id="cohere", environment_variable="LLAMA_API_KEY")
-mistral = ApiKeyProviderInfo(id="mistral", environment_variable="MISTRAL_API_KEY")
-openai = ApiKeyProviderInfo(id="openai", environment_variable="OPENAI_API_KEY")
-writer = ApiKeyProviderInfo(id="writer", environment_variable="WRITER_API_KEY")
+anthropic = ProviderInfo(
+    id="anthropic",
+    environment_variable="ANTHROPIC_API_KEY",
+    factory=lambda: AnthropicModel(
+        client_args={
+            "api_key": os.getenv("ANTHROPIC_API_KEY"),
+        },
+        model_id="claude-3-7-sonnet-20250219",
+        max_tokens=512,
+    ),
+)
+bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel())
+cohere = ProviderInfo(
+    id="cohere",
+    environment_variable="CO_API_KEY",
+    factory=lambda: OpenAIModel(
+        client_args={
+            "base_url": "https://api.cohere.com/compatibility/v1",
+            "api_key": os.getenv("CO_API_KEY"),
+        },
+        model_id="command-a-03-2025",
+        params={"stream_options": None},
+    ),
+)
+litellm = ProviderInfo(
+    id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
+)
+llama = ProviderInfo(
+    id="llama",
+    environment_variable="LLAMA_API_KEY",
+    factory=lambda: LlamaAPIModel(
+        model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
+        client_args={
+            "api_key": os.getenv("LLAMA_API_KEY"),
+        },
+    ),
+)
+mistral = ProviderInfo(
+    id="mistral",
+    environment_variable="MISTRAL_API_KEY",
+    factory=lambda: MistralModel(
+        model_id="mistral-medium-latest",
+        api_key=os.getenv("MISTRAL_API_KEY"),
+        stream=True,
+        temperature=0.7,
+        max_tokens=1000,
+        top_p=0.9,
+    ),
+)
+openai = ProviderInfo(
+    id="openai",
+    environment_variable="OPENAI_API_KEY",
+    factory=lambda: OpenAIModel(
+        model_id="gpt-4o",
+        client_args={
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    ),
+)
+writer = ProviderInfo(
+    id="writer",
+    environment_variable="WRITER_API_KEY",
+    factory=lambda: WriterModel(
+        model_id="palmyra-x4",
+        client_args={"api_key": os.getenv("WRITER_API_KEY", "")},
+        stream_options={"include_usage": True},
+    ),
+)
 
 ollama = OllamaProviderInfo()
+
+
+all_providers = [
+    bedrock,
+    anthropic,
+    cohere,
+    llama,
+    litellm,
+    mistral,
+    openai,
+    writer,
+]
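Reviewer note: the conformance test leans on two pytest features working together. `pytest.param(..., marks=[provider_info.mark])` attaches a `skipif` mark per parameter at collection time, and the parametrized `provider_info` fixture fans the single test out across `all_providers`. Below is a minimal, self-contained sketch of the same pattern outside the strands codebase; `FakeModel`, `make_param`, and `FAKE_API_KEY` are illustrative names, not part of this diff.

import os

import pytest


class FakeModel:
    """Illustrative stand-in for a provider model; not part of this diff."""


def make_param(provider_id, env_var=None):
    # Same idea as ProviderInfo: skip at collection time when a required
    # environment variable is absent; otherwise expose a lazy factory.
    skip_mark = pytest.mark.skipif(
        env_var is not None and env_var not in os.environ,
        reason=f"{env_var} environment variable missing",
    )
    return pytest.param(lambda: FakeModel(), id=provider_id, marks=[skip_mark])


@pytest.fixture(params=[make_param("gated", "FAKE_API_KEY"), make_param("ungated")])
def model(request):
    return request.param()  # each param carries a factory; call it to build the model


def test_model_can_be_constructed(model):
    assert model is not None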
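Because each `pytest.param` carries `id=provider_info.id`, the generated test IDs read like `test_model_can_be_constructed[anthropic]`, so a single provider can be exercised with `pytest tests_integ/models/conformance.py -k anthropic`, and providers whose environment variable is unset are reported as skipped rather than failed. That is also why the factories are lambdas: no provider client is constructed at import time or for a skipped provider, only when its test actually runs.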