Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
4. Manage recursive execution cycles
"""

import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator

Expand Down Expand Up @@ -189,7 +189,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
MAX_ATTEMPTS,
attempt + 1,
)
time.sleep(current_delay)
await asyncio.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield EventLoopThrottleEvent(delay=current_delay)
Expand Down
10 changes: 6 additions & 4 deletions tests/strands/agent/hooks/test_agent_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ async def streaming_tool():


@pytest.fixture
def mock_time():
with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
def mock_sleep():
with unittest.mock.patch.object(
strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock
) as mock:
yield mock


Expand Down Expand Up @@ -322,7 +324,7 @@ async def test_stream_e2e_success(alist):


@pytest.mark.asyncio
async def test_stream_e2e_throttle_and_redact(alist, mock_time):
async def test_stream_e2e_throttle_and_redact(alist, mock_sleep):
model = MagicMock()
model.stream.side_effect = [
ModelThrottledException("ThrottlingException | ConverseStream"),
Expand Down Expand Up @@ -389,7 +391,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_time):
async def test_event_loop_cycle_text_response_throttling_early_end(
agenerator,
alist,
mock_time,
mock_sleep,
):
model = MagicMock()
model.stream.side_effect = [
Expand Down
24 changes: 13 additions & 11 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@


@pytest.fixture
def mock_time():
with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
def mock_sleep():
with unittest.mock.patch.object(
strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock
) as mock:
yield mock


Expand Down Expand Up @@ -186,7 +188,7 @@ async def test_event_loop_cycle_text_response(

@pytest.mark.asyncio
async def test_event_loop_cycle_text_response_throttling(
mock_time,
mock_sleep,
agent,
model,
agenerator,
Expand Down Expand Up @@ -215,12 +217,12 @@ async def test_event_loop_cycle_text_response_throttling(

assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
# Verify that sleep was called once with the initial delay
mock_time.sleep.assert_called_once()
mock_sleep.assert_called_once()


@pytest.mark.asyncio
async def test_event_loop_cycle_exponential_backoff(
mock_time,
mock_sleep,
agent,
model,
agenerator,
Expand Down Expand Up @@ -254,13 +256,13 @@ async def test_event_loop_cycle_exponential_backoff(

# Verify that sleep was called with increasing delays
# Initial delay is 4, then 8, then 16
assert mock_time.sleep.call_count == 3
assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)]
assert mock_sleep.call_count == 3
assert mock_sleep.call_args_list == [call(4), call(8), call(16)]


@pytest.mark.asyncio
async def test_event_loop_cycle_text_response_throttling_exceeded(
mock_time,
mock_sleep,
agent,
model,
alist,
Expand All @@ -281,7 +283,7 @@ async def test_event_loop_cycle_text_response_throttling_exceeded(
)
await alist(stream)

mock_time.sleep.assert_has_calls(
mock_sleep.assert_has_calls(
[
call(4),
call(8),
Expand Down Expand Up @@ -687,7 +689,7 @@ async def test_event_loop_tracing_with_throttling_exception(
]

# Mock the time.sleep function to speed up the test
with patch("strands.event_loop.event_loop.time.sleep"):
with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock):
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
Expand Down Expand Up @@ -816,7 +818,7 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a


@pytest.mark.asyncio
async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider):
async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, agenerator, alist, hook_provider):
"""Test that model hooks are correctly emitted even when throttled."""
# Set up the model to raise throttling exceptions multiple times before succeeding
exception = ModelThrottledException("ThrottlingException | ConverseStream")
Expand Down
Loading