|
16 | 16 | from concurrent.futures import ThreadPoolExecutor |
17 | 17 | from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast |
18 | 18 |
|
19 | | -from opentelemetry import trace |
| 19 | +from opentelemetry import trace as trace_api |
20 | 20 | from pydantic import BaseModel |
21 | 21 |
|
22 | 22 | from ..event_loop.event_loop import event_loop_cycle, run_tool |
@@ -300,7 +300,7 @@ def __init__( |
300 | 300 |
|
301 | 301 | # Initialize tracer instance (no-op if not configured) |
302 | 302 | self.tracer = get_tracer() |
303 | | - self.trace_span: Optional[trace.Span] = None |
| 303 | + self.trace_span: Optional[trace_api.Span] = None |
304 | 304 |
|
305 | 305 | # Initialize agent state management |
306 | 306 | if state is not None: |
@@ -504,23 +504,25 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A |
504 | 504 | message: Message = {"role": "user", "content": content} |
505 | 505 |
|
506 | 506 | self._start_agent_trace_span(message) |
507 | | - |
508 | | - try: |
509 | | - events = self._run_loop(message, invocation_state=kwargs) |
510 | | - async for event in events: |
511 | | - if "callback" in event: |
512 | | - callback_handler(**event["callback"]) |
513 | | - yield event["callback"] |
514 | | - |
515 | | - result = AgentResult(*event["stop"]) |
516 | | - callback_handler(result=result) |
517 | | - yield {"result": result} |
518 | | - |
519 | | - self._end_agent_trace_span(response=result) |
520 | | - |
521 | | - except Exception as e: |
522 | | - self._end_agent_trace_span(error=e) |
523 | | - raise |
| 507 | + if self.trace_span is not None: |
| 508 | + span = self.trace_span |
| 509 | + with trace_api.use_span(span): |
| 510 | + try: |
| 511 | + events = self._run_loop(message, invocation_state=kwargs) |
| 512 | + async for event in events: |
| 513 | + if "callback" in event: |
| 514 | + callback_handler(**event["callback"]) |
| 515 | + yield event["callback"] |
| 516 | + |
| 517 | + result = AgentResult(*event["stop"]) |
| 518 | + callback_handler(result=result) |
| 519 | + yield {"result": result} |
| 520 | + |
| 521 | + self._end_agent_trace_span(response=result) |
| 522 | + |
| 523 | + except Exception as e: |
| 524 | + self._end_agent_trace_span(error=e) |
| 525 | + raise |
524 | 526 |
|
525 | 527 | async def _run_loop( |
526 | 528 | self, message: Message, invocation_state: dict[str, Any] |
@@ -659,7 +661,6 @@ def _start_agent_trace_span(self, message: Message) -> None: |
659 | 661 | message: The user message. |
660 | 662 | """ |
661 | 663 | model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None |
662 | | - |
663 | 664 | self.trace_span = self.tracer.start_agent_span( |
664 | 665 | message=message, |
665 | 666 | agent_name=self.name, |
|
0 commit comments