Skip to content

Commit 568b097

Browse files
authored
iterative agent (#295)
1 parent 6a1ccea commit 568b097

File tree

3 files changed

+82
-200
lines changed

3 files changed

+82
-200
lines changed

src/strands/agent/agent.py

Lines changed: 45 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,18 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import asyncio
1312
import json
1413
import logging
1514
import os
1615
import random
1716
from concurrent.futures import ThreadPoolExecutor
18-
from threading import Thread
19-
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
20-
from uuid import uuid4
17+
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
2118

2219
from opentelemetry import trace
2320
from pydantic import BaseModel
2421

2522
from ..event_loop.event_loop import event_loop_cycle
26-
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
23+
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2724
from ..handlers.tool_handler import AgentToolHandler
2825
from ..models.bedrock import BedrockModel
2926
from ..telemetry.metrics import EventLoopMetrics
@@ -210,7 +207,7 @@ def __init__(
210207
self,
211208
model: Union[Model, str, None] = None,
212209
messages: Optional[Messages] = None,
213-
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
210+
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
214211
system_prompt: Optional[str] = None,
215212
callback_handler: Optional[
216213
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
@@ -282,7 +279,7 @@ def __init__(
282279
self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()
283280

284281
# Process trace attributes to ensure they're of compatible types
285-
self.trace_attributes: Dict[str, AttributeValue] = {}
282+
self.trace_attributes: dict[str, AttributeValue] = {}
286283
if trace_attributes:
287284
for k, v in trace_attributes.items():
288285
if isinstance(v, (str, int, float, bool)) or (
@@ -339,7 +336,7 @@ def tool(self) -> ToolCaller:
339336
return self.tool_caller
340337

341338
@property
342-
def tool_names(self) -> List[str]:
339+
def tool_names(self) -> list[str]:
343340
"""Get a list of all registered tool names.
344341
345342
Returns:
@@ -384,19 +381,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
384381
- metrics: Performance metrics from the event loop
385382
- state: The final state of the event loop
386383
"""
384+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
385+
387386
self._start_agent_trace_span(prompt)
388387

389388
try:
390-
# Run the event loop and get the result
391-
result = self._run_loop(prompt, kwargs)
389+
events = self._run_loop(callback_handler, prompt, kwargs)
390+
for event in events:
391+
if "callback" in event:
392+
callback_handler(**event["callback"])
393+
394+
stop_reason, message, metrics, state = event["stop"]
395+
result = AgentResult(stop_reason, message, metrics, state)
392396

393397
self._end_agent_trace_span(response=result)
394398

395399
return result
400+
396401
except Exception as e:
397402
self._end_agent_trace_span(error=e)
398-
399-
# Re-raise the exception to preserve original behavior
400403
raise
401404

402405
def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
@@ -460,83 +463,56 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
460463
yield event["data"]
461464
```
462465
"""
463-
self._start_agent_trace_span(prompt)
466+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
464467

465-
_stop_event = uuid4()
466-
467-
queue = asyncio.Queue[Any]()
468-
loop = asyncio.get_event_loop()
469-
470-
def enqueue(an_item: Any) -> None:
471-
nonlocal queue
472-
nonlocal loop
473-
loop.call_soon_threadsafe(queue.put_nowait, an_item)
474-
475-
def queuing_callback_handler(**handler_kwargs: Any) -> None:
476-
enqueue(handler_kwargs.copy())
468+
self._start_agent_trace_span(prompt)
477469

478-
def target_callback() -> None:
479-
nonlocal kwargs
470+
try:
471+
events = self._run_loop(callback_handler, prompt, kwargs)
472+
for event in events:
473+
if "callback" in event:
474+
callback_handler(**event["callback"])
475+
yield event["callback"]
480476

481-
try:
482-
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
483-
self._end_agent_trace_span(response=result)
484-
except Exception as e:
485-
self._end_agent_trace_span(error=e)
486-
enqueue(e)
487-
finally:
488-
enqueue(_stop_event)
477+
stop_reason, message, metrics, state = event["stop"]
478+
result = AgentResult(stop_reason, message, metrics, state)
489479

490-
thread = Thread(target=target_callback, daemon=True)
491-
thread.start()
480+
self._end_agent_trace_span(response=result)
492481

493-
try:
494-
while True:
495-
item = await queue.get()
496-
if item == _stop_event:
497-
break
498-
if isinstance(item, Exception):
499-
raise item
500-
yield item
501-
finally:
502-
thread.join()
482+
except Exception as e:
483+
self._end_agent_trace_span(error=e)
484+
raise
503485

504486
def _run_loop(
505-
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
506-
) -> AgentResult:
487+
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
488+
) -> Generator[dict[str, Any], None, None]:
507489
"""Execute the agent's event loop with the given prompt and parameters."""
508490
try:
509-
# If the call had a callback_handler passed in, then for this event_loop
510-
# cycle we call both handlers as the callback_handler
511-
invocation_callback_handler = (
512-
CompositeCallbackHandler(self.callback_handler, supplementary_callback_handler)
513-
if supplementary_callback_handler is not None
514-
else self.callback_handler
515-
)
516-
517491
# Extract key parameters
518-
invocation_callback_handler(init_event_loop=True, **kwargs)
492+
yield {"callback": {"init_event_loop": True, **kwargs}}
519493

520494
# Set up the user message with optional knowledge base retrieval
521-
message_content: List[ContentBlock] = [{"text": prompt}]
495+
message_content: list[ContentBlock] = [{"text": prompt}]
522496
new_message: Message = {"role": "user", "content": message_content}
523497
self.messages.append(new_message)
524498

525499
# Execute the event loop cycle with retry logic for context limits
526-
return self._execute_event_loop_cycle(invocation_callback_handler, kwargs)
500+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
527501

528502
finally:
529503
self.conversation_manager.apply_management(self)
530504

531-
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
505+
def _execute_event_loop_cycle(
506+
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
507+
) -> Generator[dict[str, Any], None, None]:
532508
"""Execute the event loop cycle with retry logic for context window limits.
533509
534510
This internal method handles the execution of the event loop cycle and implements
535511
retry logic for handling context window overflow exceptions by reducing the
536512
conversation context and retrying.
537513
538-
Returns:
539-
The result of the event loop cycle.
514+
Yields:
515+
Events of the loop cycle.
540516
"""
541517
# Extract parameters with fallbacks to instance values
542518
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
@@ -551,7 +527,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
551527

552528
try:
553529
# Execute the main event loop cycle
554-
events = event_loop_cycle(
530+
yield from event_loop_cycle(
555531
model=model,
556532
system_prompt=system_prompt,
557533
messages=messages, # will be modified by event_loop_cycle
@@ -564,26 +540,18 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
564540
event_loop_parent_span=self.trace_span,
565541
**kwargs,
566542
)
567-
for event in events:
568-
if "callback" in event:
569-
callback_handler(**event["callback"])
570-
571-
stop_reason, message, metrics, state = event["stop"]
572-
573-
return AgentResult(stop_reason, message, metrics, state)
574543

575544
except ContextWindowOverflowException as e:
576545
# Try reducing the context size and retrying
577-
578546
self.conversation_manager.reduce_context(self, e=e)
579-
return self._execute_event_loop_cycle(callback_handler_override, kwargs)
547+
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
580548

581549
def _record_tool_execution(
582550
self,
583-
tool: Dict[str, Any],
584-
tool_result: Dict[str, Any],
551+
tool: dict[str, Any],
552+
tool_result: dict[str, Any],
585553
user_message_override: Optional[str],
586-
messages: List[Dict[str, Any]],
554+
messages: list[dict[str, Any]],
587555
) -> None:
588556
"""Record a tool execution in the message history.
589557
@@ -662,7 +630,7 @@ def _end_agent_trace_span(
662630
error: Error to record as a trace attribute.
663631
"""
664632
if self.trace_span:
665-
trace_attributes: Dict[str, Any] = {
633+
trace_attributes: dict[str, Any] = {
666634
"span": self.trace_span,
667635
}
668636

tests-integ/test_agent_async.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
import strands
4+
5+
6+
@pytest.fixture
7+
def agent():
8+
return strands.Agent()
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_stream_async(agent):
13+
stream = agent.stream_async("hello")
14+
15+
exp_message = ""
16+
async for event in stream:
17+
if "event" in event and "contentBlockDelta" in event["event"]:
18+
exp_message += event["event"]["contentBlockDelta"]["delta"]["text"]
19+
20+
tru_message = agent.messages[-1]["content"][0]["text"]
21+
22+
assert tru_message == exp_message

0 commit comments

Comments
 (0)