Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ ignore_missing_imports = true

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]

[tool.ruff.lint]
select = [
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added tests_integ/models/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions tests_integ/models/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from dataclasses import dataclass

import requests
from pytest import mark


@dataclass
class ApiKeyProviderInfo:
"""Provider-based info for providers that require an APIKey via environment variables."""

def __init__(self, id: str, environment_variable: str) -> None:
self.id = id
self.environment_variable = environment_variable
self.mark = mark.skipif(
self.environment_variable not in os.environ,
reason=f"{self.environment_variable} environment variable missing",
)


class OllamaProviderInfo:
"""Special case ollama as it's dependent on the server being available."""

def __init__(self):
self.id = "ollama"

is_server_available = False
try:
is_server_available = requests.get("http://localhost:11434").ok
except requests.exceptions.ConnectionError:
pass

self.mark = mark.skipif(
not is_server_available,
reason="Local Ollama endpoint not available at localhost:11434",
)


anthropic = ApiKeyProviderInfo(id="anthropic", environment_variable="ANTHROPIC_API_KEY")
cohere = ApiKeyProviderInfo(id="cohere", environment_variable="CO_API_KEY")
llama = ApiKeyProviderInfo(id="cohere", environment_variable="LLAMA_API_KEY")
mistral = ApiKeyProviderInfo(id="mistral", environment_variable="MISTRAL_API_KEY")
openai = ApiKeyProviderInfo(id="openai", environment_variable="OPENAI_API_KEY")
writer = ApiKeyProviderInfo(id="writer", environment_variable="WRITER_API_KEY")

ollama = OllamaProviderInfo()
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.anthropic import AnthropicModel
from tests_integ.models import providers

# these tests only run if we have the anthropic api key
pytestmark = providers.anthropic.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,15 +57,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -70,7 +72,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -83,14 +84,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import strands
from strands import Agent
from strands.models.openai import OpenAIModel
from tests_integ.models import providers

# these tests only run if we have the cohere api key
pytestmark = providers.cohere.mark


@pytest.fixture
Expand Down Expand Up @@ -37,10 +41,6 @@ def agent(model, tools):
return Agent(model=model, tools=tools)


@pytest.mark.skipif(
"CO_API_KEY" not in os.environ,
reason="CO_API_KEY environment variable missing",
)
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.llamaapi import LlamaAPIModel
from tests_integ.models import providers

# these tests only run if we have the llama api key
pytestmark = providers.llama.mark


@pytest.fixture
Expand Down Expand Up @@ -36,10 +40,6 @@ def agent(model, tools):
return Agent(model=model, tools=tools)


@pytest.mark.skipif(
"LLAMA_API_KEY" not in os.environ,
reason="LLAMA_API_KEY environment variable missing",
)
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.mistral import MistralModel
from tests_integ.models import providers

# these tests only run if we have the mistral api key
pytestmark = providers.mistral.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -76,15 +80,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -93,7 +95,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -106,14 +107,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
def test_agent_structured_output(non_streaming_agent, weather):
tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_structured_output_async(non_streaming_agent, weather):
tru_weather = await non_streaming_agent.structured_output_async(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import pytest
import requests
from pydantic import BaseModel

import strands
from strands import Agent
from strands.models.ollama import OllamaModel
from tests_integ.models import providers


def is_server_available() -> bool:
try:
return requests.get("http://localhost:11434").ok
except requests.exceptions.ConnectionError:
return False
# these tests only run if we have the ollama is running
pytestmark = providers.ollama.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -48,15 +44,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -65,7 +59,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -78,14 +71,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
def test_agent_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import strands
from strands import Agent, tool

if "OPENAI_API_KEY" not in os.environ:
pytest.skip(allow_module_level=True, reason="OPENAI_API_KEY environment variable missing")

from strands.models.openai import OpenAIModel
from tests_integ.models import providers

# these tests only run if we have the openai api key
pytestmark = providers.openai.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,7 +53,7 @@ class Weather(BaseModel):

@pytest.fixture(scope="module")
def test_image_path(request):
return request.config.rootpath / "tests-integ" / "test_image.png"
return request.config.rootpath / "tests_integ" / "test_image.png"


def test_agent_invoke(agent):
Expand Down Expand Up @@ -96,6 +96,7 @@ async def test_agent_structured_output_async(agent, weather):
assert tru_weather == exp_weather


@pytest.skip("https://github.com/strands-agents/sdk-python/issues/320")
def test_tool_returning_images(model, test_image_path):
@tool
def tool_with_image_return():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.writer import WriterModel
from tests_integ.models import providers

# these tests only run if we have the writer api key
pytestmark = providers.writer.mark


@pytest.fixture
Expand Down Expand Up @@ -40,7 +44,6 @@ def agent(model, tools, system_prompt):
return Agent(model=model, tools=tools, system_prompt=system_prompt, load_tools_from_directory=False)


@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand All @@ -49,7 +52,6 @@ def test_agent(agent):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_agent_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand All @@ -58,7 +60,6 @@ async def test_agent_async(agent):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
async for event in stream:
Expand All @@ -70,7 +71,6 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
def test_structured_output(agent):
class Weather(BaseModel):
time: str
Expand All @@ -84,7 +84,6 @@ class Weather(BaseModel):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_structured_output_async(agent):
class Weather(BaseModel):
time: str
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def calculator(x: int, y: int) -> int:
@mcp.tool(description="Generates a custom image")
def generate_custom_image() -> MCPImageContent:
try:
with open("tests-integ/test_image.png", "rb") as image_file:
with open("tests_integ/test_image.png", "rb") as image_file:
encoded_image = base64.b64encode(image_file.read())
return MCPImageContent(type="image", data=encoded_image, mimeType="image/png")
except Exception as e:
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_mcp_client():

sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse"))
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
with sse_mcp_client, stdio_mcp_client:
agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
Expand All @@ -90,7 +90,7 @@ def test_mcp_client():

def test_can_reuse_mcp_client():
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
with stdio_mcp_client:
stdio_mcp_client.list_tools_sync()
Expand Down
File renamed without changes.
Loading
Loading