Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf
logger.debug("config=<%s> | initializing", self.config)

client_args = client_args or {}
self.client = anthropic.Anthropic(**client_args)
self.client = anthropic.AsyncAnthropic(**client_args)

@override
def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # type: ignore[override]
Expand Down Expand Up @@ -358,8 +358,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
ModelThrottledException: If the request is throttled by Anthropic.
"""
try:
with self.client.messages.stream(**request) as stream:
for event in stream:
async with self.client.messages.stream(**request) as stream:
async for event in stream:
if event.type in AnthropicModel.EVENT_TYPES:
yield event.model_dump()

Expand Down
63 changes: 49 additions & 14 deletions tests-integ/test_model_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from strands.models.anthropic import AnthropicModel


@pytest.fixture
@pytest.fixture(scope="module")
def model():
return AnthropicModel(
client_args={
Expand All @@ -19,7 +19,7 @@ def model():
)


@pytest.fixture
@pytest.fixture(scope="module")
def tools():
@strands.tool
def tool_time() -> str:
Expand All @@ -32,32 +32,67 @@ def tool_weather() -> str:
return [tool_time, tool_weather]


@pytest.fixture
@pytest.fixture(scope="module")
def system_prompt():
return "You are an AI assistant."


@pytest.fixture
@pytest.fixture(scope="module")
def agent(model, tools, system_prompt):
return Agent(model=model, tools=tools, system_prompt=system_prompt)


@pytest.fixture(scope="module")
def weather():
class Weather(BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""

time: str
weather: str

return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_agent(agent):
def test_agent_invoke(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_structured_output(model):
class Weather(BaseModel):
time: str
weather: str
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
result = await agent.invoke_async("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])

agent = Agent(model=model)
result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
assert isinstance(result, Weather)
assert result.time == "12:00"
assert result.weather == "sunny"

@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_stream_async(agent):
stream = agent.stream_async("What is the time and weather in New York?")
async for event in stream:
_ = event

result = event["result"]
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
def test_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather
18 changes: 9 additions & 9 deletions tests/strands/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def anthropic_client():
with unittest.mock.patch.object(strands.models.anthropic.anthropic, "Anthropic") as mock_client_cls:
with unittest.mock.patch.object(strands.models.anthropic.anthropic, "AsyncAnthropic") as mock_client_cls:
yield mock_client_cls.return_value


Expand Down Expand Up @@ -625,7 +625,7 @@ def test_format_chunk_unknown(model):


@pytest.mark.asyncio
async def test_stream(anthropic_client, model, alist):
async def test_stream(anthropic_client, model, agenerator, alist):
mock_event_1 = unittest.mock.Mock(
type="message_start",
dict=lambda: {"type": "message_start"},
Expand All @@ -646,9 +646,9 @@ async def test_stream(anthropic_client, model, alist):
),
)

mock_stream = unittest.mock.MagicMock()
mock_stream.__iter__.return_value = iter([mock_event_1, mock_event_2, mock_event_3])
anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream
mock_context = unittest.mock.AsyncMock()
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
anthropic_client.messages.stream.return_value = mock_context

request = {"model": "m1"}
response = model.stream(request)
Expand Down Expand Up @@ -705,7 +705,7 @@ async def test_stream_bad_request_error(anthropic_client, model):


@pytest.mark.asyncio
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

events = [
Expand Down Expand Up @@ -749,9 +749,9 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
),
]

mock_stream = unittest.mock.MagicMock()
mock_stream.__iter__.return_value = iter(events)
anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream
mock_context = unittest.mock.AsyncMock()
mock_context.__aenter__.return_value = agenerator(events)
anthropic_client.messages.stream.return_value = mock_context

stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)
Expand Down
Loading