Skip to content

Commit 9210c0f

Browse files
pgrayyjsamuel1
authored andcommitted
iterative agent (strands-agents#295)
1 parent 78692d4 commit 9210c0f

File tree

3 files changed

+82
-200
lines changed

3 files changed

+82
-200
lines changed

src/strands/agent/agent.py

Lines changed: 45 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,18 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import asyncio
1312
import json
1413
import logging
1514
import os
1615
import random
1716
from concurrent.futures import ThreadPoolExecutor
18-
from threading import Thread
19-
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
20-
from uuid import uuid4
17+
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
2118

2219
from opentelemetry import trace
2320
from pydantic import BaseModel
2421

2522
from ..event_loop.event_loop import event_loop_cycle
26-
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
23+
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2724
from ..handlers.tool_handler import AgentToolHandler
2825
from ..models.bedrock import BedrockModel
2926
from ..telemetry.metrics import EventLoopMetrics
@@ -211,7 +208,7 @@ def __init__(
211208
self,
212209
model: Union[Model, str, None] = None,
213210
messages: Optional[Messages] = None,
214-
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
211+
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
215212
system_prompt: Optional[str] = None,
216213
callback_handler: Optional[
217214
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
@@ -291,7 +288,7 @@ def __init__(
291288
self.tool_manager.set_agent(self)
292289

293290
# Process trace attributes to ensure they're of compatible types
294-
self.trace_attributes: Dict[str, AttributeValue] = {}
291+
self.trace_attributes: dict[str, AttributeValue] = {}
295292
if trace_attributes:
296293
for k, v in trace_attributes.items():
297294
if isinstance(v, (str, int, float, bool)) or (
@@ -348,7 +345,7 @@ def tool(self) -> ToolCaller:
348345
return self.tool_caller
349346

350347
@property
351-
def tool_names(self) -> List[str]:
348+
def tool_names(self) -> list[str]:
352349
"""Get a list of all registered tool names.
353350
354351
Returns:
@@ -424,19 +421,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
424421
- metrics: Performance metrics from the event loop
425422
- state: The final state of the event loop
426423
"""
424+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
425+
427426
self._start_agent_trace_span(prompt)
428427

429428
try:
430-
# Run the event loop and get the result
431-
result = self._run_loop(prompt, kwargs)
429+
events = self._run_loop(callback_handler, prompt, kwargs)
430+
for event in events:
431+
if "callback" in event:
432+
callback_handler(**event["callback"])
433+
434+
stop_reason, message, metrics, state = event["stop"]
435+
result = AgentResult(stop_reason, message, metrics, state)
432436

433437
self._end_agent_trace_span(response=result)
434438

435439
return result
440+
436441
except Exception as e:
437442
self._end_agent_trace_span(error=e)
438-
439-
# Re-raise the exception to preserve original behavior
440443
raise
441444

442445
def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
@@ -500,83 +503,56 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
500503
yield event["data"]
501504
```
502505
"""
503-
self._start_agent_trace_span(prompt)
506+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
504507

505-
_stop_event = uuid4()
506-
507-
queue = asyncio.Queue[Any]()
508-
loop = asyncio.get_event_loop()
509-
510-
def enqueue(an_item: Any) -> None:
511-
nonlocal queue
512-
nonlocal loop
513-
loop.call_soon_threadsafe(queue.put_nowait, an_item)
514-
515-
def queuing_callback_handler(**handler_kwargs: Any) -> None:
516-
enqueue(handler_kwargs.copy())
508+
self._start_agent_trace_span(prompt)
517509

518-
def target_callback() -> None:
519-
nonlocal kwargs
510+
try:
511+
events = self._run_loop(callback_handler, prompt, kwargs)
512+
for event in events:
513+
if "callback" in event:
514+
callback_handler(**event["callback"])
515+
yield event["callback"]
520516

521-
try:
522-
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
523-
self._end_agent_trace_span(response=result)
524-
except Exception as e:
525-
self._end_agent_trace_span(error=e)
526-
enqueue(e)
527-
finally:
528-
enqueue(_stop_event)
517+
stop_reason, message, metrics, state = event["stop"]
518+
result = AgentResult(stop_reason, message, metrics, state)
529519

530-
thread = Thread(target=target_callback, daemon=True)
531-
thread.start()
520+
self._end_agent_trace_span(response=result)
532521

533-
try:
534-
while True:
535-
item = await queue.get()
536-
if item == _stop_event:
537-
break
538-
if isinstance(item, Exception):
539-
raise item
540-
yield item
541-
finally:
542-
thread.join()
522+
except Exception as e:
523+
self._end_agent_trace_span(error=e)
524+
raise
543525

544526
def _run_loop(
545-
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
546-
) -> AgentResult:
527+
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
528+
) -> Generator[dict[str, Any], None, None]:
547529
"""Execute the agent's event loop with the given prompt and parameters."""
548530
try:
549-
# If the call had a callback_handler passed in, then for this event_loop
550-
# cycle we call both handlers as the callback_handler
551-
invocation_callback_handler = (
552-
CompositeCallbackHandler(self.callback_handler, supplementary_callback_handler)
553-
if supplementary_callback_handler is not None
554-
else self.callback_handler
555-
)
556-
557531
# Extract key parameters
558-
invocation_callback_handler(init_event_loop=True, **kwargs)
532+
yield {"callback": {"init_event_loop": True, **kwargs}}
559533

560534
# Set up the user message with optional knowledge base retrieval
561-
message_content: List[ContentBlock] = [{"text": prompt}]
535+
message_content: list[ContentBlock] = [{"text": prompt}]
562536
new_message: Message = {"role": "user", "content": message_content}
563537
self.messages.append(new_message)
564538

565539
# Execute the event loop cycle with retry logic for context limits
566-
return self._execute_event_loop_cycle(invocation_callback_handler, kwargs)
540+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
567541

568542
finally:
569543
self.conversation_manager.apply_management(self)
570544

571-
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
545+
def _execute_event_loop_cycle(
546+
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
547+
) -> Generator[dict[str, Any], None, None]:
572548
"""Execute the event loop cycle with retry logic for context window limits.
573549
574550
This internal method handles the execution of the event loop cycle and implements
575551
retry logic for handling context window overflow exceptions by reducing the
576552
conversation context and retrying.
577553
578-
Returns:
579-
The result of the event loop cycle.
554+
Yields:
555+
Events of the loop cycle.
580556
"""
581557
# Extract parameters with fallbacks to instance values
582558
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
@@ -605,7 +581,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
605581

606582
try:
607583
# Execute the main event loop cycle
608-
events = event_loop_cycle(
584+
yield from event_loop_cycle(
609585
model=model,
610586
system_prompt=system_prompt,
611587
messages=messages, # will be modified by event_loop_cycle
@@ -618,26 +594,18 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
618594
event_loop_parent_span=self.trace_span,
619595
**kwargs,
620596
)
621-
for event in events:
622-
if "callback" in event:
623-
callback_handler(**event["callback"])
624-
625-
stop_reason, message, metrics, state = event["stop"]
626-
627-
return AgentResult(stop_reason, message, metrics, state)
628597

629598
except ContextWindowOverflowException as e:
630599
# Try reducing the context size and retrying
631-
632600
self.conversation_manager.reduce_context(self, e=e)
633-
return self._execute_event_loop_cycle(callback_handler_override, kwargs)
601+
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
634602

635603
def _record_tool_execution(
636604
self,
637-
tool: Dict[str, Any],
638-
tool_result: Dict[str, Any],
605+
tool: dict[str, Any],
606+
tool_result: dict[str, Any],
639607
user_message_override: Optional[str],
640-
messages: List[Dict[str, Any]],
608+
messages: list[dict[str, Any]],
641609
) -> None:
642610
"""Record a tool execution in the message history.
643611
@@ -716,7 +684,7 @@ def _end_agent_trace_span(
716684
error: Error to record as a trace attribute.
717685
"""
718686
if self.trace_span:
719-
trace_attributes: Dict[str, Any] = {
687+
trace_attributes: dict[str, Any] = {
720688
"span": self.trace_span,
721689
}
722690

tests-integ/test_agent_async.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
import strands
4+
5+
6+
@pytest.fixture
7+
def agent():
8+
return strands.Agent()
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_stream_async(agent):
13+
stream = agent.stream_async("hello")
14+
15+
exp_message = ""
16+
async for event in stream:
17+
if "event" in event and "contentBlockDelta" in event["event"]:
18+
exp_message += event["event"]["contentBlockDelta"]["delta"]["text"]
19+
20+
tru_message = agent.messages[-1]["content"][0]["text"]
21+
22+
assert tru_message == exp_message

0 commit comments

Comments
 (0)