Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 72 additions & 53 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from opentelemetry import context as otel_context, trace

Check failure on line 12 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

livekit-agents/livekit/agents/voice/agent_activity.py:12:38: F401 `opentelemetry.context` imported but unused

from livekit import rtc
from livekit.agents.llm.realtime import MessageGeneration
from livekit.agents.metrics.base import Metadata

Check failure on line 16 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

livekit-agents/livekit/agents/voice/agent_activity.py:16:41: F401 `livekit.agents.metrics.base.Metadata` imported but unused

from .. import llm, stt, tts, utils, vad
from ..llm.tool_context import StopResponse
Expand All @@ -26,7 +26,7 @@
TTSMetrics,
VADMetrics,
)
from ..telemetry import trace_types, tracer, utils as trace_utils

Check failure on line 29 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

livekit-agents/livekit/agents/voice/agent_activity.py:29:55: F401 `..telemetry.utils` imported but unused
from ..tokenize.basic import split_words
from ..types import NOT_GIVEN, NotGivenOr
from ..utils.misc import is_given
Expand Down Expand Up @@ -89,7 +89,8 @@
def __init__(self, agent: Agent, sess: AgentSession) -> None:
self._agent, self._session = agent, sess
self._rt_session: llm.RealtimeSession | None = None
self._realtime_spans: utils.BoundedDict[str, trace.Span] | None = None
# Lock to coordinate speech assignment with interruptions

Check failure on line 92 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:92:66: W291 Trailing whitespace
self._speech_assignment_lock = asyncio.Lock()

Check failure on line 93 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:93:54: W291 Trailing whitespace
self._audio_recognition: AudioRecognition | None = None
self._lock = asyncio.Lock()
self._tool_choice: llm.ToolChoice | None = None
Expand Down Expand Up @@ -360,19 +361,7 @@
if speech_handle is not None:
tk1 = _SpeechHandleContextVar.set(speech_handle)

# Capture the current OpenTelemetry context to ensure proper span nesting
current_context = otel_context.get_current()

# Create a wrapper coroutine that runs in the captured context
async def _context_aware_coro() -> Any:
# Attach the captured context before running the original coroutine
token = otel_context.attach(current_context)
try:
return await coro
finally:
otel_context.detach(token)

task = asyncio.create_task(_context_aware_coro(), name=name)
task = asyncio.create_task(coro, name=name)
self._speech_tasks.append(task)
task.add_done_callback(lambda _: self._speech_tasks.remove(task))

Expand Down Expand Up @@ -515,7 +504,6 @@
except llm.RealtimeError:
logger.exception("failed to update the tools")

self._realtime_spans = utils.BoundedDict[str, trace.Span](maxsize=100)
if (
not self.llm.capabilities.audio_output
and not self.tts
Expand Down Expand Up @@ -664,9 +652,6 @@
if self._rt_session is not None:
await self._rt_session.aclose()

if self._realtime_spans is not None:
self._realtime_spans.clear()

if self._audio_recognition is not None:
await self._audio_recognition.aclose()

Expand Down Expand Up @@ -957,16 +942,19 @@
# skip done speech (interrupted when it's in the queue)
self._current_speech = None
continue
# async with speech._assignment_lock:

Check failure on line 945 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:945:54: W291 Trailing whitespace
self._current_speech = speech
if self.min_consecutive_speech_delay > 0.0:
await asyncio.sleep(
self.min_consecutive_speech_delay - (time.time() - last_playout_ts)
)

# check again if speech is done after sleep delay
if speech.done():
# skip done speech (interrupted during delay)
self._current_speech = None
continue

speech._authorize_generation()
await speech._wait_for_generation()
self._current_speech = None
Expand Down Expand Up @@ -1013,12 +1001,6 @@
isinstance(ev, LLMMetrics) or isinstance(ev, TTSMetrics)
):
ev.speech_id = speech_handle.id
if (
isinstance(ev, RealtimeModelMetrics)
and self._realtime_spans is not None
and (realtime_span := self._realtime_spans.pop(ev.request_id, None))
):
trace_utils.record_realtime_metrics(realtime_span, ev)
self._session.emit("metrics_collected", MetricsCollectedEvent(metrics=ev))

def _on_error(
Expand Down Expand Up @@ -1072,11 +1054,14 @@
self._session._conversation_item_added(msg)

def _on_generation_created(self, ev: llm.GenerationCreatedEvent) -> None:
print(f"🎯 AGENT: Generation created - user_initiated={ev.user_initiated}")

if ev.user_initiated:
# user_initiated generations are directly handled inside _realtime_reply_task
return

if self._scheduling_paused:
print(f" ❌ BLOCKED: Scheduling paused!")

Check failure on line 1064 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F541)

livekit-agents/livekit/agents/voice/agent_activity.py:1064:19: F541 f-string without any placeholders
# TODO(theomonnom): should we "forward" this new turn to the next agent?
logger.warning("skipping new realtime generation, the speech scheduling is not running")
return
Expand Down Expand Up @@ -1134,28 +1119,38 @@
self._session.output.audio.pause()
self._session._update_agent_state("listening")
else:
if self._rt_session is not None:
self._rt_session.interrupt()
# if self._rt_session is not None:
# self._rt_session.interrupt()

self._current_speech.interrupt()
# self._current_speech.interrupt()

# Use lock to prevent interrupting a speech that's being assigned

Check failure on line 1127 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:1127:82: W291 Trailing whitespace
async def _do_interrupt():

Check failure on line 1128 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:1128:43: W291 Trailing whitespace
async with self._speech_assignment_lock:

Check failure on line 1129 in livekit-agents/livekit/agents/voice/agent_activity.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

livekit-agents/livekit/agents/voice/agent_activity.py:1129:61: W291 Trailing whitespace
if self._rt_session is not None:
self._rt_session.interrupt()

# Only interrupt if this is still the current speech
if self._current_speech and self._current_speech == human_speech:
self._current_speech.interrupt()

# Schedule the interrupt (can't await in sync function)
asyncio.create_task(_do_interrupt())

# region recognition hooks

def on_start_of_speech(self, ev: vad.VADEvent | None) -> None:
def on_start_of_speech(self, ev: vad.VADEvent) -> None:
self._session._update_user_state("speaking")

if self._false_interruption_timer:
# cancel the timer when user starts speaking but leave the paused state unchanged
self._false_interruption_timer.cancel()
self._false_interruption_timer = None

def on_end_of_speech(self, ev: vad.VADEvent | None) -> None:
speech_end_time = time.time()
if ev:
speech_end_time = speech_end_time - ev.silence_duration
def on_end_of_speech(self, ev: vad.VADEvent) -> None:
self._session._update_user_state(
"listening",
last_speaking_time=speech_end_time,
last_speaking_time=time.time() - ev.silence_duration,
)

if (
Expand Down Expand Up @@ -1419,22 +1414,13 @@
# await the interrupt to make sure user message is added to the chat context before the new task starts
await speech_handle.interrupt()

metadata: Metadata | None = None
if isinstance(self.turn_detection, str):
metadata = Metadata(model_name="unknown", model_provider=self.turn_detection)
elif self.turn_detection is not None:
metadata = Metadata(
model_name=self.turn_detection.model, model_provider=self.turn_detection.provider
)

eou_metrics = EOUMetrics(
timestamp=time.time(),
end_of_utterance_delay=info.end_of_utterance_delay,
transcription_delay=info.transcription_delay,
on_user_turn_completed_delay=callback_duration,
speech_id=speech_handle.id,
last_speaking_time=info.last_speaking_time,
metadata=metadata,
)
self._session.emit("metrics_collected", MetricsCollectedEvent(metrics=eou_metrics))

Expand Down Expand Up @@ -1500,7 +1486,6 @@
node=self._agent.tts_node,
input=audio_source,
model_settings=model_settings,
text_transforms=self._session.options.tts_text_transforms,
)
tasks.append(tts_task)
if (
Expand Down Expand Up @@ -1642,7 +1627,6 @@
node=self._agent.tts_node,
input=tts_text_input,
model_settings=model_settings,
text_transforms=self._session.options.tts_text_transforms,
)
tasks.append(tts_task)
if (
Expand Down Expand Up @@ -1803,7 +1787,6 @@
)

if generated_msg:
chat_ctx.insert(generated_msg)
self._agent._chat_ctx.insert(generated_msg)
self._session._conversation_item_added(generated_msg)
current_span.set_attribute(
Expand Down Expand Up @@ -1962,16 +1945,14 @@
model_settings: ModelSettings,
instructions: str | None = None,
) -> None:
print(f"🎬 STARTING _realtime_generation_task")
current_span = trace.get_current_span()
current_span.set_attribute(trace_types.ATTR_SPEECH_ID, speech_handle.id)

assert self._rt_session is not None, "rt_session is not available"
assert isinstance(self.llm, llm.RealtimeModel), "llm is not a realtime model"

current_span.set_attribute(trace_types.ATTR_GEN_AI_REQUEST_MODEL, self.llm.model)
if self._realtime_spans is not None and generation_ev.response_id:
self._realtime_spans[generation_ev.response_id] = current_span

print("checking audio output")
audio_output = self._session.output.audio if self._session.output.audio_enabled else None
text_output = (
self._session.output.transcription
Expand All @@ -1980,14 +1961,19 @@
)
tool_ctx = llm.ToolContext(self.tools)

print("waiting for authorization")
wait_for_authorization = asyncio.ensure_future(speech_handle._wait_for_authorization())

print("waiting for speech handle")
await speech_handle.wait_if_not_interrupted([wait_for_authorization])
speech_handle._clear_authorization()

print(f"🔍 CHECK: speech_handle.interrupted = {speech_handle.interrupted}")
if speech_handle.interrupted:
await utils.aio.cancel_and_wait(wait_for_authorization)
print(f"🚨 EARLY EXIT: Speech interrupted, returning early")
# await utils.aio.cancel_and_wait(wait_for_authorization)
current_span.set_attribute(trace_types.ATTR_SPEECH_INTERRUPTED, True)
return # TODO(theomonnom): remove the message from the serverside history
# return # TODO(theomonnom): remove the message from the serverside history

def _on_first_frame(_: asyncio.Future[None]) -> None:
self._session._update_agent_state("speaking")
Expand All @@ -2002,12 +1988,34 @@
async def _read_messages(
outputs: list[tuple[MessageGeneration, _TextOutput | None, _AudioOutput | None]],
) -> None:
print("reading messages")
nonlocal read_transcript_from_tts
assert isinstance(self.llm, llm.RealtimeModel)

forward_tasks: list[asyncio.Task[Any]] = []
try:
async for msg in generation_ev.message_stream:
# async for msg in generation_ev.message_stream:
stream_iter = generation_ev.message_stream.__aiter__()

## __aiter__(self) → returns the async iterator object itself
## __anext__(self) → returns an awaitable that produces the next item, or raises StopAsyncIteration

while True:
try:
print(f"🔄 LOOP: Attempting to get next message...")
# Use timeout to prevent infinite blocking on message stream
msg = await asyncio.wait_for(stream_iter.__anext__(), timeout=3.0)
print(f"💬 MESSAGE: Got message from stream")
except asyncio.TimeoutError:
print("Message stream timeout - checking for interruption")
if speech_handle.interrupted:
print("Speech interrupted, breaking message reading loop")
break
continue
except StopAsyncIteration:
print("Message stream ended normally")
break

if len(forward_tasks) > 0:
logger.warning(
"expected to receive only one message generation from the realtime API"
Expand All @@ -2017,6 +2025,7 @@
msg_modalities = await msg.modalities
tts_text_input: AsyncIterable[str] | None = None
if "audio" not in msg_modalities and self.tts:
print("audio not in msg_modalities")
if self.llm.capabilities.audio_output:
logger.warning(
"text response received from realtime API, falling back to use a TTS model."
Expand All @@ -2036,7 +2045,6 @@
node=self._agent.tts_node,
input=tts_text_input,
model_settings=model_settings,
text_transforms=self._session.options.tts_text_transforms,
)

if (
Expand All @@ -2054,6 +2062,7 @@
tasks.append(tts_task)
realtime_audio_result = tts_gen_data.audio_ch
elif "audio" in msg_modalities:
print("audio in msg_modalities")
realtime_audio = self._agent.realtime_audio_output_node(
msg.audio_stream, model_settings
)
Expand All @@ -2074,6 +2083,7 @@
)

if realtime_audio_result is not None:
print("before calling perform_audio_forwarding")
forward_task, audio_out = perform_audio_forwarding(
audio_output=audio_output,
tts_output=realtime_audio_result,
Expand All @@ -2099,17 +2109,25 @@

await asyncio.gather(*forward_tasks)
finally:
print("🔚 EXITING _read_messages task")
await utils.aio.cancel_and_wait(*forward_tasks)

print(f"🔍 CHECK: generation_ev = {generation_ev}")
print(f"🔍 CHECK: generation_ev.message_stream = {generation_ev.message_stream}")
print(f"🔍 CHECK: audio_output = {audio_output}")
print(f"🔍 CHECK: text_output = {text_output}")

message_outputs: list[
tuple[MessageGeneration, _TextOutput | None, _AudioOutput | None]
] = []
print(f"🔨 ABOUT TO CREATE _read_messages task")
tasks.append(
asyncio.create_task(
_read_messages(message_outputs),
name="AgentActivity.realtime_generation.read_messages",
)
)
print(f"✅ CREATED _read_messages task")

# read function calls
fnc_tee = utils.aio.itertools.tee(generation_ev.function_stream, 2)
Expand Down Expand Up @@ -2243,6 +2261,7 @@
try:
await exe_task
finally:
print(f"🎬 EXITING _realtime_generation_task")
self._background_speeches.discard(speech_handle)

# important: no agent output should be used after this point
Expand Down
15 changes: 5 additions & 10 deletions livekit-agents/livekit/agents/voice/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,12 @@ class _TTSGenerationData:


def perform_tts_inference(
*,
node: io.TTSNode,
input: AsyncIterable[str],
model_settings: ModelSettings,
text_transforms: Sequence[TextTransforms] | None,
*, node: io.TTSNode, input: AsyncIterable[str], model_settings: ModelSettings
) -> tuple[asyncio.Task[bool], _TTSGenerationData]:
audio_ch = aio.Chan[rtc.AudioFrame]()
timed_texts_fut = asyncio.Future[Optional[aio.Chan[io.TimedString]]]()
data = _TTSGenerationData(audio_ch=audio_ch, timed_texts_fut=timed_texts_fut)

if text_transforms:
from .transcription.filters import apply_text_transforms

input = apply_text_transforms(input, text_transforms)

tts_task = asyncio.create_task(_tts_inference_task(node, input, model_settings, data))

def _inference_done(_: asyncio.Task[bool]) -> None:
Expand Down Expand Up @@ -277,6 +268,7 @@ def perform_audio_forwarding(
audio_output: io.AudioOutput,
tts_output: AsyncIterable[rtc.AudioFrame],
) -> tuple[asyncio.Task[None], _AudioOutput]:
print(f"🚀 PERFORM: Starting audio forwarding, output_type={type(audio_output).__name__}")
out = _AudioOutput(audio=[], first_frame_fut=asyncio.Future())
task = asyncio.create_task(_audio_forwarding_task(audio_output, tts_output, out))
return task, out
Expand All @@ -288,10 +280,13 @@ async def _audio_forwarding_task(
tts_output: AsyncIterable[rtc.AudioFrame],
out: _AudioOutput,
) -> None:
print(f"🔄 FORWARDING: Started, output_id={id(audio_output)}")
print(f"🔄 FORWARDING: Task started, output_type={type(audio_output).__name__}")
resampler: rtc.AudioResampler | None = None
try:
audio_output.resume()
async for frame in tts_output:
print(f"📨 FORWARDING: Got frame {len(frame.data)} bytes, calling capture_frame")
out.audio.append(frame)

if (
Expand Down
Loading
Loading