Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ jobs:
AWS_REGION_NAME: us-east-1 # Needed for LiteLLM
id: tests
run: |
hatch test tests-integ
hatch test tests_integ
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ test = [
"hatch test --cover --cov-report html --cov-report xml {args}"
]
test-integ = [
"hatch test tests-integ {args}"
"hatch test tests_integ {args}"
]
prepare = [
"hatch fmt --linter",
Expand Down Expand Up @@ -230,7 +230,7 @@ ignore_missing_imports = true

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]

[tool.ruff.lint]
select = [
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added tests_integ/models/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions tests_integ/models/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from dataclasses import dataclass

import requests
from pytest import mark


@dataclass
class ApiKeyProviderInfo:
"""Provider-based info for providers that require an APIKey via environment variables."""

def __init__(self, id: str, environment_variable: str) -> None:
self.id = id
self.environment_variable = environment_variable
self.mark = mark.skipif(
self.environment_variable not in os.environ,
reason=f"{self.environment_variable} environment variable missing",
)


class OllamaProviderInfo:
"""Special case ollama as it's dependent on the server being available."""

def __init__(self):
self.id = "ollama"

is_server_available = False
try:
is_server_available = requests.get("http://localhost:11434").ok
except requests.exceptions.ConnectionError:
pass

self.mark = mark.skipif(
not is_server_available,
reason="Local Ollama endpoint not available at localhost:11434",
)


anthropic = ApiKeyProviderInfo(id="anthropic", environment_variable="ANTHROPIC_API_KEY")
cohere = ApiKeyProviderInfo(id="cohere", environment_variable="CO_API_KEY")
llama = ApiKeyProviderInfo(id="cohere", environment_variable="LLAMA_API_KEY")
mistral = ApiKeyProviderInfo(id="mistral", environment_variable="MISTRAL_API_KEY")
openai = ApiKeyProviderInfo(id="openai", environment_variable="OPENAI_API_KEY")
writer = ApiKeyProviderInfo(id="writer", environment_variable="WRITER_API_KEY")

ollama = OllamaProviderInfo()
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.anthropic import AnthropicModel
from tests_integ.models import providers

# these tests only run if we have the anthropic api key
pytestmark = providers.anthropic.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,15 +57,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -70,7 +72,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -83,14 +84,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import strands
from strands import Agent
from strands.models.openai import OpenAIModel
from tests_integ.models import providers

# these tests only run if we have the cohere api key
pytestmark = providers.cohere.mark


@pytest.fixture
Expand Down Expand Up @@ -37,10 +41,6 @@ def agent(model, tools):
return Agent(model=model, tools=tools)


@pytest.mark.skipif(
"CO_API_KEY" not in os.environ,
reason="CO_API_KEY environment variable missing",
)
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.llamaapi import LlamaAPIModel
from tests_integ.models import providers

# these tests only run if we have the llama api key
pytestmark = providers.llama.mark


@pytest.fixture
Expand Down Expand Up @@ -36,10 +40,6 @@ def agent(model, tools):
return Agent(model=model, tools=tools)


@pytest.mark.skipif(
"LLAMA_API_KEY" not in os.environ,
reason="LLAMA_API_KEY environment variable missing",
)
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.mistral import MistralModel
from tests_integ.models import providers

# these tests only run if we have the mistral api key
pytestmark = providers.mistral.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -76,15 +80,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -93,7 +95,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -106,14 +107,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
def test_agent_structured_output(non_streaming_agent, weather):
tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_structured_output_async(non_streaming_agent, weather):
tru_weather = await non_streaming_agent.structured_output_async(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import pytest
import requests
from pydantic import BaseModel

import strands
from strands import Agent
from strands.models.ollama import OllamaModel
from tests_integ.models import providers


def is_server_available() -> bool:
try:
return requests.get("http://localhost:11434").ok
except requests.exceptions.ConnectionError:
return False
# these tests only run if we have the ollama is running
pytestmark = providers.ollama.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -48,15 +44,13 @@ class Weather(BaseModel):
return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
Expand All @@ -65,7 +59,6 @@ async def test_agent_invoke_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
Expand All @@ -78,14 +71,12 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
def test_agent_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import strands
from strands import Agent, tool

if "OPENAI_API_KEY" not in os.environ:
pytest.skip(allow_module_level=True, reason="OPENAI_API_KEY environment variable missing")

from strands.models.openai import OpenAIModel
from tests_integ.models import providers

# these tests only run if we have the openai api key
pytestmark = providers.openai.mark


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,7 +53,7 @@ class Weather(BaseModel):

@pytest.fixture(scope="module")
def test_image_path(request):
return request.config.rootpath / "tests-integ" / "test_image.png"
return request.config.rootpath / "tests_integ" / "test_image.png"


def test_agent_invoke(agent):
Expand Down Expand Up @@ -96,6 +96,7 @@ async def test_agent_structured_output_async(agent, weather):
assert tru_weather == exp_weather


@pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320")
def test_tool_returning_images(model, test_image_path):
@tool
def tool_with_image_return():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import strands
from strands import Agent
from strands.models.writer import WriterModel
from tests_integ.models import providers

# these tests only run if we have the writer api key
pytestmark = providers.writer.mark


@pytest.fixture
Expand Down Expand Up @@ -40,7 +44,6 @@ def agent(model, tools, system_prompt):
return Agent(model=model, tools=tools, system_prompt=system_prompt, load_tools_from_directory=False)


@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand All @@ -49,7 +52,6 @@ def test_agent(agent):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_agent_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()
Expand All @@ -58,7 +60,6 @@ async def test_agent_async(agent):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
async for event in stream:
Expand All @@ -70,7 +71,6 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
def test_structured_output(agent):
class Weather(BaseModel):
time: str
Expand All @@ -84,7 +84,6 @@ class Weather(BaseModel):


@pytest.mark.asyncio
@pytest.mark.skipif("WRITER_API_KEY" not in os.environ, reason="WRITER_API_KEY environment variable missing")
async def test_structured_output_async(agent):
class Weather(BaseModel):
time: str
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def calculator(x: int, y: int) -> int:
@mcp.tool(description="Generates a custom image")
def generate_custom_image() -> MCPImageContent:
try:
with open("tests-integ/test_image.png", "rb") as image_file:
with open("tests_integ/test_image.png", "rb") as image_file:
encoded_image = base64.b64encode(image_file.read())
return MCPImageContent(type="image", data=encoded_image, mimeType="image/png")
except Exception as e:
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_mcp_client():

sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse"))
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
with sse_mcp_client, stdio_mcp_client:
agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
Expand All @@ -90,7 +90,7 @@ def test_mcp_client():

def test_can_reuse_mcp_client():
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"]))
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
with stdio_mcp_client:
stdio_mcp_client.list_tools_sync()
Expand Down
File renamed without changes.
Loading
Loading