Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace
from opentelemetry import trace as trace_api
from pydantic import BaseModel

from ..event_loop.event_loop import event_loop_cycle, run_tool
Expand Down Expand Up @@ -300,7 +300,7 @@ def __init__(

# Initialize tracer instance (no-op if not configured)
self.tracer = get_tracer()
self.trace_span: Optional[trace.Span] = None
self.trace_span: Optional[trace_api.Span] = None

# Initialize agent state management
if state is not None:
Expand Down Expand Up @@ -503,24 +503,24 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
message: Message = {"role": "user", "content": content}

self._start_agent_trace_span(message)
self.trace_span = self._start_agent_trace_span(message)
with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(message, invocation_state=kwargs)
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]

try:
events = self._run_loop(message, invocation_state=kwargs)
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]
result = AgentResult(*event["stop"])
callback_handler(result=result)
yield {"result": result}

result = AgentResult(*event["stop"])
callback_handler(result=result)
yield {"result": result}
self._end_agent_trace_span(response=result)

self._end_agent_trace_span(response=result)

except Exception as e:
self._end_agent_trace_span(error=e)
raise
except Exception as e:
self._end_agent_trace_span(error=e)
raise

async def _run_loop(
self, message: Message, invocation_state: dict[str, Any]
Expand Down Expand Up @@ -652,15 +652,14 @@ def _record_tool_execution(
self._append_message(tool_result_msg)
self._append_message(assistant_msg)

def _start_agent_trace_span(self, message: Message) -> None:
def _start_agent_trace_span(self, message: Message) -> trace_api.Span:
"""Starts a trace span for the agent.

Args:
message: The user message.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
return self.tracer.start_agent_span(
message=message,
agent_name=self.name,
model_id=model_id,
Expand Down
115 changes: 60 additions & 55 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast

from opentelemetry import trace as trace_api

from ..experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
Expand Down Expand Up @@ -114,72 +116,75 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
parent_span=cycle_span,
model_id=model_id,
)

tool_specs = agent.tool_registry.get_all_tool_specs()

agent.hooks.invoke_callbacks(
BeforeModelInvocationEvent(
agent=agent,
)
)

try:
# TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
# before yielding to the callback handler. This will be revisited when migrating to strongly
# typed events.
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
if "callback" in event:
yield {
"callback": {**event["callback"], **(invocation_state if "delta" in event["callback"] else {})}
}

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
with trace_api.use_span(model_invoke_span):
tool_specs = agent.tool_registry.get_all_tool_specs()

agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
BeforeModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
stop_reason=stop_reason,
message=message,
),
)
)

if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
break # Success! Break out of retry loop

except Exception as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)

agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
exception=e,
try:
# TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
# before yielding to the callback handler. This will be revisited when migrating to strongly
# typed events.
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
if "callback" in event:
yield {
"callback": {
**event["callback"],
**(invocation_state if "delta" in event["callback"] else {}),
}
}

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})

agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
stop_reason=stop_reason,
message=message,
),
)
)
)

if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
raise e
if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
break # Success! Break out of retry loop

logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
except Exception as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)

agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
exception=e,
)
)
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
else:
raise e
if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
raise e

logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
)
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
else:
raise e

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
4 changes: 2 additions & 2 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def start_model_invoke_span(
parent_span: Optional[Span] = None,
model_id: Optional[str] = None,
**kwargs: Any,
) -> Optional[Span]:
) -> Span:
"""Start a new span for a model invocation.

Args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def start_agent_span(
tools: Optional[list] = None,
custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
**kwargs: Any,
) -> Optional[Span]:
) -> Span:
"""Start a new span for an agent invocation.

Args:
Expand Down
Loading