Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions livekit-agents/livekit/agents/voice/run_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import json
import os
from collections.abc import Generator
from collections.abc import AsyncIterator, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Expand All @@ -26,7 +26,7 @@
from ..llm import function_tool, utils as llm_utils
from ..telemetry import trace_types, tracer
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import is_given
from ..utils import aio, is_given
from .speech_handle import SpeechHandle

if TYPE_CHECKING:
Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(self, *, user_input: str | None = None, output_type: type[Run_T] |
self._user_input = user_input
self._output_type = output_type
self._recorded_items: list[RunEvent] = []
self._recorded_items_ch: aio.Chan[RunEvent] = aio.Chan()
self._final_output: Run_T | None = None

self.__last_speech_handle: SpeechHandle | None = None
Expand Down Expand Up @@ -134,12 +135,28 @@ async def _await_impl() -> RunResult[Run_T]:

return _await_impl().__await__()

async def __anext__(self) -> RunEvent:
try:
val = await self._recorded_items_ch.__anext__()
except StopAsyncIteration:
if self._done_fut.done() and (exc := self._done_fut.exception()):
raise exc # noqa: B904

raise StopAsyncIteration from None

return val

def __aiter__(self) -> AsyncIterator[RunEvent]:
# NOTE: the order of FunctionCallEvent may not be the same as that in final result
return self

def _agent_handoff(
self, *, item: llm.AgentHandoff, old_agent: Agent | None, new_agent: Agent
) -> None:
event = AgentHandoffEvent(item=item, old_agent=old_agent, new_agent=new_agent)
index = self._find_insertion_index(created_at=event.item.created_at)
self._recorded_items.insert(index, event)
self._recorded_items_ch.send_nowait(event)

def _item_added(self, item: llm.ChatItem) -> None:
if self._done_fut.done():
Expand All @@ -156,6 +173,7 @@ def _item_added(self, item: llm.ChatItem) -> None:
if event is not None:
index = self._find_insertion_index(created_at=event.item.created_at)
self._recorded_items.insert(index, event)
self._recorded_items_ch.send_nowait(event)

def _watch_handle(self, handle: SpeechHandle | asyncio.Task) -> None:
self._handles.add(handle)
Expand All @@ -180,6 +198,7 @@ def _mark_done_if_needed(self, handle: SpeechHandle | asyncio.Task | None) -> No
self._mark_done()

def _mark_done(self) -> None:
self._recorded_items_ch.close()
with contextlib.suppress(asyncio.InvalidStateError):
if self.__last_speech_handle is None:
self._done_fut.set_result(None)
Expand Down
Loading