diff --git a/livekit-agents/livekit/agents/voice/agent_activity.py b/livekit-agents/livekit/agents/voice/agent_activity.py
index 88a439f86d..0b864106da 100644
--- a/livekit-agents/livekit/agents/voice/agent_activity.py
+++ b/livekit-agents/livekit/agents/voice/agent_activity.py
@@ -89,7 +89,8 @@ class AgentActivity(RecognitionHooks):
     def __init__(self, agent: Agent, sess: AgentSession) -> None:
         self._agent, self._session = agent, sess
         self._rt_session: llm.RealtimeSession | None = None
-        self._realtime_spans: utils.BoundedDict[str, trace.Span] | None = None
+        # Lock to coordinate speech assignment with interruptions
+        self._speech_assignment_lock = asyncio.Lock()
         self._audio_recognition: AudioRecognition | None = None
         self._lock = asyncio.Lock()
         self._tool_choice: llm.ToolChoice | None = None
@@ -360,19 +361,7 @@ def _create_speech_task(
         if speech_handle is not None:
             tk1 = _SpeechHandleContextVar.set(speech_handle)
 
-        # Capture the current OpenTelemetry context to ensure proper span nesting
-        current_context = otel_context.get_current()
-
-        # Create a wrapper coroutine that runs in the captured context
-        async def _context_aware_coro() -> Any:
-            # Attach the captured context before running the original coroutine
-            token = otel_context.attach(current_context)
-            try:
-                return await coro
-            finally:
-                otel_context.detach(token)
-
-        task = asyncio.create_task(_context_aware_coro(), name=name)
+        task = asyncio.create_task(coro, name=name)
         self._speech_tasks.append(task)
         task.add_done_callback(lambda _: self._speech_tasks.remove(task))
 
@@ -515,7 +504,6 @@ async def _list_mcp_tools_task(
             except llm.RealtimeError:
                 logger.exception("failed to update the tools")
 
-            self._realtime_spans = utils.BoundedDict[str, trace.Span](maxsize=100)
             if (
                 not self.llm.capabilities.audio_output
                 and not self.tts
@@ -664,9 +652,6 @@ async def _close_session(self) -> None:
         if self._rt_session is not None:
             await self._rt_session.aclose()
 
-        if self._realtime_spans is not None:
-            self._realtime_spans.clear()
-
         if self._audio_recognition is not None:
             await self._audio_recognition.aclose()
 
@@ -957,16 +942,19 @@ async def _scheduling_task(self) -> None:
                 # skip done speech (interrupted when it's in the queue)
                 self._current_speech = None
                 continue
 
+            # async with speech._assignment_lock:
             self._current_speech = speech
             if self.min_consecutive_speech_delay > 0.0:
                 await asyncio.sleep(
                     self.min_consecutive_speech_delay - (time.time() - last_playout_ts)
                 )
+                # check again if speech is done after sleep delay
                 if speech.done():
                     # skip done speech (interrupted during delay)
                     self._current_speech = None
                     continue
+
             speech._authorize_generation()
             await speech._wait_for_generation()
             self._current_speech = None
@@ -1013,12 +1001,6 @@ def _on_metrics_collected(
                 isinstance(ev, LLMMetrics) or isinstance(ev, TTSMetrics)
             ):
                 ev.speech_id = speech_handle.id
-            if (
-                isinstance(ev, RealtimeModelMetrics)
-                and self._realtime_spans is not None
-                and (realtime_span := self._realtime_spans.pop(ev.request_id, None))
-            ):
-                trace_utils.record_realtime_metrics(realtime_span, ev)
 
         self._session.emit("metrics_collected", MetricsCollectedEvent(metrics=ev))
 
     def _on_error(
@@ -1072,11 +1054,14 @@ def _on_input_audio_transcription_completed(self, ev: llm.InputTranscriptionComp
             self._session._conversation_item_added(msg)
 
     def _on_generation_created(self, ev: llm.GenerationCreatedEvent) -> None:
+        print(f"🎯 AGENT: Generation created - user_initiated={ev.user_initiated}")
+
         if ev.user_initiated:
             # user_initiated generations are directly handled inside _realtime_reply_task
            return
 
        if self._scheduling_paused:
+            print(f"  ❌ BLOCKED: Scheduling paused!")
            # TODO(theomonnom): should we "forward" this new turn to the next agent?
            logger.warning("skipping new realtime generation, the speech scheduling is not running")
            return
 
@@ -1134,14 +1119,28 @@ def _interrupt_by_audio_activity(self) -> None:
                self._session.output.audio.pause()
                self._session._update_agent_state("listening")
            else:
-                if self._rt_session is not None:
-                    self._rt_session.interrupt()
+                # if self._rt_session is not None:
+                #     self._rt_session.interrupt()
 
-                self._current_speech.interrupt()
+                # self._current_speech.interrupt()
+
+                # Use lock to prevent interrupting a speech that's being assigned
+                current_speech = self._current_speech
+
+                async def _do_interrupt():
+                    async with self._speech_assignment_lock:
+                        if self._rt_session is not None:
+                            self._rt_session.interrupt()
+
+                        # Only interrupt if this is still the current speech
+                        if self._current_speech and self._current_speech is current_speech:
+                            self._current_speech.interrupt()
+
+                # Schedule the interrupt (can't await in sync function)
+                asyncio.create_task(_do_interrupt())
 
    # region recognition hooks
 
-    def on_start_of_speech(self, ev: vad.VADEvent | None) -> None:
+    def on_start_of_speech(self, ev: vad.VADEvent) -> None:
        self._session._update_user_state("speaking")
 
        if self._false_interruption_timer:
@@ -1149,13 +1147,10 @@ def on_start_of_speech(self, ev: vad.VADEvent | None) -> None:
            self._false_interruption_timer.cancel()
            self._false_interruption_timer = None
 
-    def on_end_of_speech(self, ev: vad.VADEvent | None) -> None:
-        speech_end_time = time.time()
-        if ev:
-            speech_end_time = speech_end_time - ev.silence_duration
+    def on_end_of_speech(self, ev: vad.VADEvent) -> None:
        self._session._update_user_state(
            "listening",
-            last_speaking_time=speech_end_time,
+            last_speaking_time=time.time() - ev.silence_duration,
        )
 
        if (
@@ -1419,14 +1414,6 @@ async def _user_turn_completed_task(
            # await the interrupt to make sure user message is added to the chat context before the new task starts
            await speech_handle.interrupt()
 
-        metadata: Metadata | None = None
-        if isinstance(self.turn_detection, str):
-            metadata = Metadata(model_name="unknown", model_provider=self.turn_detection)
-        elif self.turn_detection is not None:
-            metadata = Metadata(
-                model_name=self.turn_detection.model, model_provider=self.turn_detection.provider
-            )
-
        eou_metrics = EOUMetrics(
            timestamp=time.time(),
            end_of_utterance_delay=info.end_of_utterance_delay,
@@ -1434,7 +1421,6 @@ async def _user_turn_completed_task(
            on_user_turn_completed_delay=callback_duration,
            speech_id=speech_handle.id,
            last_speaking_time=info.last_speaking_time,
-            metadata=metadata,
        )
 
        self._session.emit("metrics_collected", MetricsCollectedEvent(metrics=eou_metrics))
 
@@ -1500,7 +1486,6 @@ def _on_first_frame(_: asyncio.Future[None]) -> None:
                node=self._agent.tts_node,
                input=audio_source,
                model_settings=model_settings,
-                text_transforms=self._session.options.tts_text_transforms,
            )
            tasks.append(tts_task)
            if (
@@ -1642,7 +1627,6 @@ async def _pipeline_reply_task(
                node=self._agent.tts_node,
                input=tts_text_input,
                model_settings=model_settings,
-                text_transforms=self._session.options.tts_text_transforms,
            )
            tasks.append(tts_task)
            if (
@@ -1803,7 +1787,6 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
        )
 
        if generated_msg:
-            chat_ctx.insert(generated_msg)
            self._agent._chat_ctx.insert(generated_msg)
            self._session._conversation_item_added(generated_msg)
            current_span.set_attribute(
@@ -1962,16 +1945,14 @@ async def _realtime_generation_task(
        model_settings: ModelSettings,
        instructions: str | None = None,
    ) -> None:
+        print(f"🎬 STARTING _realtime_generation_task")
        current_span = trace.get_current_span()
        current_span.set_attribute(trace_types.ATTR_SPEECH_ID, speech_handle.id)
 
        assert self._rt_session is not None, "rt_session is not available"
        assert isinstance(self.llm, llm.RealtimeModel), "llm is not a realtime model"
 
-        current_span.set_attribute(trace_types.ATTR_GEN_AI_REQUEST_MODEL, self.llm.model)
-        if self._realtime_spans is not None and generation_ev.response_id:
-            self._realtime_spans[generation_ev.response_id] = current_span
-
+        print("checking audio output")
        audio_output = self._session.output.audio if self._session.output.audio_enabled else None
        text_output = (
            self._session.output.transcription
            if self._session.output.transcription_enabled
            else None
        )
        tool_ctx = llm.ToolContext(self.tools)
 
+        print("waiting for authorization")
        wait_for_authorization = asyncio.ensure_future(speech_handle._wait_for_authorization())
+
+        print("waiting for speech handle")
        await speech_handle.wait_if_not_interrupted([wait_for_authorization])
        speech_handle._clear_authorization()
 
+        print(f"🔍 CHECK: speech_handle.interrupted = {speech_handle.interrupted}")
        if speech_handle.interrupted:
-            await utils.aio.cancel_and_wait(wait_for_authorization)
+            print(f"🚨 EARLY EXIT: Speech interrupted, returning early")
+            # await utils.aio.cancel_and_wait(wait_for_authorization)
            current_span.set_attribute(trace_types.ATTR_SPEECH_INTERRUPTED, True)
-            return  # TODO(theomonnom): remove the message from the serverside history
+            # return  # TODO(theomonnom): remove the message from the serverside history
 
        def _on_first_frame(_: asyncio.Future[None]) -> None:
            self._session._update_agent_state("speaking")
 
@@ -2002,12 +1988,34 @@ def _on_first_frame(_: asyncio.Future[None]) -> None:
        async def _read_messages(
            outputs: list[tuple[MessageGeneration, _TextOutput | None, _AudioOutput | None]],
        ) -> None:
+            print("reading messages")
            nonlocal read_transcript_from_tts
            assert isinstance(self.llm, llm.RealtimeModel)
 
            forward_tasks: list[asyncio.Task[Any]] = []
            try:
-                async for msg in generation_ev.message_stream:
+                # async for msg in generation_ev.message_stream:
+                stream_iter = generation_ev.message_stream.__aiter__()
+
+                ## __aiter__(self) → returns the async iterator object itself
+                ## __anext__(self) → returns an awaitable that produces the next item, or raises StopAsyncIteration
+
+                while True:
+                    try:
+                        print(f"🔄 LOOP: Attempting to get next message...")
+                        # Use timeout to prevent infinite blocking on message stream
+                        msg = await asyncio.wait_for(stream_iter.__anext__(), timeout=3.0)
+                        print(f"💬 MESSAGE: Got message from stream")
+                    except asyncio.TimeoutError:
+                        print("Message stream timeout - checking for interruption")
+                        if speech_handle.interrupted:
+                            print("Speech interrupted, breaking message reading loop")
+                            break
+                        continue
+                    except StopAsyncIteration:
+                        print("Message stream ended normally")
+                        break
+
                    if len(forward_tasks) > 0:
                        logger.warning(
                            "expected to receive only one message generation from the realtime API"
@@ -2017,6 +2025,7 @@ async def _read_messages(
                    msg_modalities = await msg.modalities
                    tts_text_input: AsyncIterable[str] | None = None
                    if "audio" not in msg_modalities and self.tts:
+                        print("audio not in msg_modalities")
                        if self.llm.capabilities.audio_output:
                            logger.warning(
                                "text response received from realtime API, falling back to use a TTS model."
@@ -2036,7 +2045,6 @@ async def _read_messages(
                            node=self._agent.tts_node,
                            input=tts_text_input,
                            model_settings=model_settings,
-                            text_transforms=self._session.options.tts_text_transforms,
                        )
 
                        if (
@@ -2054,6 +2062,7 @@ async def _read_messages(
                        tasks.append(tts_task)
                        realtime_audio_result = tts_gen_data.audio_ch
                    elif "audio" in msg_modalities:
+                        print("audio in msg_modalities")
                        realtime_audio = self._agent.realtime_audio_output_node(
                            msg.audio_stream, model_settings
                        )
@@ -2074,6 +2083,7 @@ async def _read_messages(
                    )
 
                    if realtime_audio_result is not None:
+                        print("before calling perform_audio_forwarding")
                        forward_task, audio_out = perform_audio_forwarding(
                            audio_output=audio_output,
                            tts_output=realtime_audio_result,
@@ -2099,17 +2109,25 @@ async def _read_messages(
 
                await asyncio.gather(*forward_tasks)
            finally:
+                print("🔚 EXITING _read_messages task")
                await utils.aio.cancel_and_wait(*forward_tasks)
 
+        print(f"🔍 CHECK: generation_ev = {generation_ev}")
+        print(f"🔍 CHECK: generation_ev.message_stream = {generation_ev.message_stream}")
+        print(f"🔍 CHECK: audio_output = {audio_output}")
+        print(f"🔍 CHECK: text_output = {text_output}")
+
        message_outputs: list[
            tuple[MessageGeneration, _TextOutput | None, _AudioOutput | None]
        ] = []
+        print(f"🔨 ABOUT TO CREATE _read_messages task")
        tasks.append(
            asyncio.create_task(
                _read_messages(message_outputs),
                name="AgentActivity.realtime_generation.read_messages",
            )
        )
+        print(f"✅ CREATED _read_messages task")
 
        # read function calls
        fnc_tee = utils.aio.itertools.tee(generation_ev.function_stream, 2)
@@ -2243,6 +2261,7 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
        try:
            await exe_task
        finally:
+            print(f"🎬 EXITING _realtime_generation_task")
            self._background_speeches.discard(speech_handle)
 
            # important: no agent ouput should be used after this point
diff --git a/livekit-agents/livekit/agents/voice/generation.py b/livekit-agents/livekit/agents/voice/generation.py
index c8026bb282..beef6a5f45 100644
--- a/livekit-agents/livekit/agents/voice/generation.py
+++ b/livekit-agents/livekit/agents/voice/generation.py
@@ -174,21 +174,12 @@ class _TTSGenerationData:
 
 
 def perform_tts_inference(
-    *,
-    node: io.TTSNode,
-    input: AsyncIterable[str],
-    model_settings: ModelSettings,
-    text_transforms: Sequence[TextTransforms] | None,
+    *, node: io.TTSNode, input: AsyncIterable[str], model_settings: ModelSettings
 ) -> tuple[asyncio.Task[bool], _TTSGenerationData]:
     audio_ch = aio.Chan[rtc.AudioFrame]()
     timed_texts_fut = asyncio.Future[Optional[aio.Chan[io.TimedString]]]()
 
     data = _TTSGenerationData(audio_ch=audio_ch, timed_texts_fut=timed_texts_fut)
 
-    if text_transforms:
-        from .transcription.filters import apply_text_transforms
-
-        input = apply_text_transforms(input, text_transforms)
-
     tts_task = asyncio.create_task(_tts_inference_task(node, input, model_settings, data))
 
     def _inference_done(_: asyncio.Task[bool]) -> None:
@@ -277,6 +268,7 @@ def perform_audio_forwarding(
     audio_output: io.AudioOutput,
     tts_output: AsyncIterable[rtc.AudioFrame],
 ) -> tuple[asyncio.Task[None], _AudioOutput]:
+    print(f"🚀 PERFORM: Starting audio forwarding, output_type={type(audio_output).__name__}")
     out = _AudioOutput(audio=[], first_frame_fut=asyncio.Future())
     task = asyncio.create_task(_audio_forwarding_task(audio_output, tts_output, out))
     return task, out
@@ -288,10 +280,13 @@ async def _audio_forwarding_task(
     audio_output: io.AudioOutput,
     tts_output: AsyncIterable[rtc.AudioFrame],
     out: _AudioOutput,
 ) -> None:
+    print(f"🔄 FORWARDING: Started, output_id={id(audio_output)}")
+    print(f"🔄 FORWARDING: Task started, output_type={type(audio_output).__name__}")
     resampler: rtc.AudioResampler | None = None
     try:
         audio_output.resume()
         async for frame in tts_output:
+            print(f"📨 FORWARDING: Got frame {len(frame.data)} bytes, calling capture_frame")
             out.audio.append(frame)
 
             if (
diff --git a/livekit-agents/livekit/agents/voice/transcription/synchronizer.py b/livekit-agents/livekit/agents/voice/transcription/synchronizer.py
index 6b49abfe62..48d11b509b 100644
--- a/livekit-agents/livekit/agents/voice/transcription/synchronizer.py
+++ b/livekit-agents/livekit/agents/voice/transcription/synchronizer.py
@@ -513,25 +513,31 @@ def __init__(
 
     async def capture_frame(self, frame: rtc.AudioFrame) -> None:
         # using barrier() on capture should be sufficient, flush() must not be called if
         # capture_frame isn't completed
-        await self._synchronizer.barrier()
+        import traceback
+        # print(f"🔒 SYNC: About to wait for barrier, {len(frame.data)} bytes")
+        # print(f"📍 CALL STACK:\n{''.join(traceback.format_stack())}")
+        # await self._synchronizer.barrier()
+        # print(f"✅ SYNC: Barrier passed, calling next_in_chain")
 
         self._capturing = True
         await super().capture_frame(frame)
         await self._next_in_chain.capture_frame(frame)  # passthrough audio
         self._pushed_duration += frame.duration
 
-        if not self._synchronizer.enabled:
-            return
+        return
 
-        if self._synchronizer._impl.audio_input_ended:
-            # this should not happen if `on_playback_finished` is called after each flush
-            logger.warning(
-                "_SegmentSynchronizerImpl audio marked as ended in capture audio, rotating segment"
-            )
-            self._synchronizer.rotate_segment()
-            await self._synchronizer.barrier()
+        # if not self._synchronizer.enabled:
+        #     return
+
+        # if self._synchronizer._impl.audio_input_ended:
+        #     # this should not happen if `on_playback_finished` is called after each flush
+        #     logger.warning(
+        #         "_SegmentSynchronizerImpl audio marked as ended in capture audio, rotating segment"
+        #     )
+        #     self._synchronizer.rotate_segment()
+        #     # await self._synchronizer.barrier()
 
-        self._synchronizer._impl.push_audio(frame)
+        # self._synchronizer._impl.push_audio(frame)
 
     def flush(self) -> None:
         super().flush()
@@ -607,7 +613,7 @@ def __init__(
         self._capturing = False
 
     async def capture_text(self, text: str) -> None:
-        await self._synchronizer.barrier()
+        # await self._synchronizer.barrier()
         if not self._synchronizer.enabled:
             # passthrough text if the synchronizer is disabled
             await self._next_in_chain.capture_text(text)
@@ -620,7 +626,7 @@ async def capture_text(self, text: str) -> None:
                 "_SegmentSynchronizerImpl text marked as ended in capture text, rotating segment"
             )
             self._synchronizer.rotate_segment()
-            await self._synchronizer.barrier()
+            # await self._synchronizer.barrier()
 
         self._synchronizer._impl.push_text(text)
 
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
index 80f74418b9..2a6eb48094 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
@@ -127,7 +127,7 @@ class _RealtimeOptions:
     voice: str
     tool_choice: llm.ToolChoice | None
     input_audio_transcription: AudioTranscription | None
-    input_audio_noise_reduction: NoiseReduction | None
+    input_audio_noise_reduction: NoiseReductionType | None
     turn_detection: RealtimeAudioInputTurnDetection | None
     max_response_output_tokens: int | Literal["inf"] | None
     tracing: Tracing | None
@@ -179,7 +179,7 @@ def __init__(
             AudioTranscription | InputAudioTranscription | None
         ] = NOT_GIVEN,
         input_audio_noise_reduction: NotGivenOr[
-            NoiseReductionType | NoiseReduction | InputAudioNoiseReduction | None
+            NoiseReductionType | InputAudioNoiseReduction | None
         ] = NOT_GIVEN,
         turn_detection: NotGivenOr[
             RealtimeAudioInputTurnDetection | TurnDetection | None
@@ -210,7 +210,7 @@ def __init__(
             AudioTranscription | InputAudioTranscription | None
         ] = NOT_GIVEN,
         input_audio_noise_reduction: NotGivenOr[
-            NoiseReductionType | NoiseReduction | InputAudioNoiseReduction | None
+            NoiseReductionType | InputAudioNoiseReduction | None
         ] = NOT_GIVEN,
         turn_detection: NotGivenOr[
             RealtimeAudioInputTurnDetection | TurnDetection | None
@@ -236,7 +236,7 @@ def __init__(
             AudioTranscription | InputAudioTranscription | None
         ] = NOT_GIVEN,
         input_audio_noise_reduction: NotGivenOr[
-            NoiseReductionType | NoiseReduction | InputAudioNoiseReduction | None
+            NoiseReductionType | InputAudioNoiseReduction | None
         ] = NOT_GIVEN,
         turn_detection: NotGivenOr[
             RealtimeAudioInputTurnDetection | TurnDetection | None
@@ -262,7 +262,7 @@ def __init__(
             tool_choice (llm.ToolChoice | None | NotGiven): Tool selection policy for responses.
             base_url (str | NotGiven): HTTP base URL of the OpenAI/Azure API. If not provided, uses OPENAI_BASE_URL for OpenAI; for Azure, constructed from AZURE_OPENAI_ENDPOINT.
             input_audio_transcription (AudioTranscription | None | NotGiven): Options for transcribing input audio.
-            input_audio_noise_reduction (NoiseReductionType | NoiseReduction | InputAudioNoiseReduction | None | NotGiven): Input audio noise reduction settings.
+            input_audio_noise_reduction (NoiseReductionType | None | NotGiven): Input audio noise reduction settings.
             turn_detection (RealtimeAudioInputTurnDetection | None | NotGiven): Server-side turn-detection options.
             speed (float | NotGiven): Audio playback speed multiplier.
             tracing (Tracing | None | NotGiven): Tracing configuration for OpenAI Realtime.
@@ -365,16 +365,6 @@ def __init__(
         self._http_session_owned = False
         self._sessions = weakref.WeakSet[RealtimeSession]()
 
-    @property
-    def model(self) -> str:
-        return self._opts.model
-
-    @property
-    def provider(self) -> str:
-        from urllib.parse import urlparse
-
-        return urlparse(self._opts.base_url).netloc
-
     @classmethod
     def with_azure(
         cls,
@@ -390,7 +380,9 @@ def with_azure(
         input_audio_transcription: NotGivenOr[
             AudioTranscription | InputAudioTranscription | None
         ] = NOT_GIVEN,
-        input_audio_noise_reduction: NoiseReductionType | InputAudioNoiseReduction | None = None,
+        input_audio_noise_reduction: NotGivenOr[
+            NoiseReductionType | InputAudioNoiseReduction | None
+        ] = NOT_GIVEN,
         turn_detection: NotGivenOr[
             RealtimeAudioInputTurnDetection | TurnDetection | None
         ] = NOT_GIVEN,
@@ -400,8 +392,7 @@ def with_azure(
         temperature: NotGivenOr[float] = NOT_GIVEN,  # deprecated, unused in v1
     ) -> RealtimeModel | RealtimeModelBeta:
         """
-        Create a RealtimeModelBeta configured for Azure OpenAI. Azure does not currently support the GA API,
-        so we return RealtimeModelBeta instead of RealtimeModel.
+        Create a RealtimeModel configured for Azure OpenAI.
 
         Args:
             azure_deployment (str): Azure OpenAI deployment name.
@@ -413,7 +404,7 @@ def with_azure(
             voice (str): Voice used for audio responses.
             modalities (list[Literal["text", "audio"]] | NotGiven): Modalities to enable. Defaults to ["text", "audio"] if not provided.
             input_audio_transcription (AudioTranscription | InputAudioTranscription | None | NotGiven): Transcription options; defaults to Azure-optimized values when not provided.
-            input_audio_noise_reduction (NoiseReductionType | InputAudioNoiseReduction | None): Input noise reduction settings. Defaults to None.
+            input_audio_noise_reduction (NoiseReductionType | InputAudioNoiseReduction | None | NotGiven): Input noise reduction settings.
             turn_detection (RealtimeAudioInputTurnDetection | TurnDetection | None | NotGiven): Server-side VAD; defaults to Azure-optimized values when not provided.
             speed (float | NotGiven): Audio playback speed multiplier.
             tracing (Tracing | None | NotGiven): Tracing configuration for OpenAI Realtime.
@@ -421,7 +412,7 @@ def with_azure(
             temperature (float | NotGiven): Deprecated; ignored by Realtime v1.
 
         Returns:
-            RealtimeModelBeta: Configured client for Azure OpenAI Realtime.
+            RealtimeModel: Configured client for Azure OpenAI Realtime.
 
         Raises:
             ValueError: If credentials are missing, `api_version` is not provided, Azure endpoint cannot be determined, or both `base_url` and `azure_endpoint` are provided.
@@ -509,32 +500,50 @@ def with_azure(
         if not is_given(turn_detection):
             turn_detection = AZURE_DEFAULT_TURN_DETECTION
 
-        if is_given(input_audio_transcription) and not isinstance(
-            input_audio_transcription, InputAudioTranscription
-        ):
-            raise ValueError(
-                f"input_audio_transcription must be an instance of InputAudioTranscription for api-version {api_version}"
-            )
-        if is_given(turn_detection) and not isinstance(turn_detection, TurnDetection):
-            raise ValueError(
-                f"turn_detection must be an instance of TurnDetection for api-version {api_version}"
-            )
-        if input_audio_noise_reduction is not None and not isinstance(
-            input_audio_noise_reduction, InputAudioNoiseReduction
-        ):
-            raise ValueError(
-                f"input_audio_noise_reduction must be an instance of InputAudioNoiseReduction for api-version {api_version}"
+        if api_version == "2024-10-01-preview":
+            if is_given(input_audio_transcription) and not isinstance(
+                input_audio_transcription, InputAudioTranscription
+            ):
+                raise ValueError(
+                    f"input_audio_transcription must be an instance of InputAudioTranscription for api-version {api_version}"
+                )
+            if is_given(turn_detection) and not isinstance(turn_detection, TurnDetection):
+                raise ValueError(
+                    f"turn_detection must be an instance of TurnDetection for api-version {api_version}"
+                )
+            if is_given(input_audio_noise_reduction) and not isinstance(
+                input_audio_noise_reduction, InputAudioNoiseReduction
+            ):
+                raise ValueError(
+                    f"input_audio_noise_reduction must be an instance of InputAudioNoiseReduction for api-version {api_version}"
+                )
+
+            return RealtimeModelBeta(
+                voice=voice,
+                modalities=modalities,
+                input_audio_transcription=input_audio_transcription,  # type: ignore
+                input_audio_noise_reduction=input_audio_noise_reduction,  # type: ignore
+                turn_detection=turn_detection,  # type: ignore
+                temperature=temperature,
+                speed=speed,
+                tracing=tracing,  # type: ignore
+                api_key=api_key,
+                http_session=http_session,
+                azure_deployment=azure_deployment,
+                api_version=api_version,
+                entra_token=entra_token,
+                base_url=base_url,
             )
 
-        return RealtimeModelBeta(
+        return cls(
             voice=voice,
             modalities=modalities,
-            input_audio_transcription=input_audio_transcription,  # type: ignore
-            input_audio_noise_reduction=input_audio_noise_reduction,
-            turn_detection=turn_detection,  # type: ignore
+            input_audio_transcription=to_audio_transcription(input_audio_transcription),
+            input_audio_noise_reduction=to_noise_reduction(input_audio_noise_reduction),
+            turn_detection=to_turn_detection(turn_detection),
             temperature=temperature,
             speed=speed,
-            tracing=tracing,  # type: ignore
+            tracing=tracing,
             api_key=api_key,
             http_session=http_session,
             azure_deployment=azure_deployment,
@@ -555,7 +564,7 @@ def update_options(
             InputAudioTranscription | AudioTranscription | None
         ] = NOT_GIVEN,
         input_audio_noise_reduction: NotGivenOr[
-            NoiseReduction | NoiseReductionType | InputAudioNoiseReduction | None
+            InputAudioNoiseReduction | NoiseReductionType | None
         ] = NOT_GIVEN,
         max_response_output_tokens: NotGivenOr[int | Literal["inf"] | None] = NOT_GIVEN,
         speed: NotGivenOr[float] = NOT_GIVEN,
@@ -725,7 +734,7 @@ async def _reconnect() -> None:
                 exclude_instructions=True,
                 exclude_empty_message=True,
             )
-            old_chat_ctx = self._remote_chat_ctx
+            old_chat_ctx_copy = copy.deepcopy(self._remote_chat_ctx)
             self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext()
             events.extend(self._create_update_chat_ctx_events(chat_ctx))
 
@@ -735,7 +744,7 @@ async def _reconnect() -> None:
                     self.emit("openai_client_event_queued", msg)
                     await ws_conn.send_str(json.dumps(msg))
             except Exception as e:
-                self._remote_chat_ctx = old_chat_ctx  # restore the old chat context
+                self._remote_chat_ctx = old_chat_ctx_copy  # restore the old chat context
                 raise APIConnectionError(
                     message=(
                         "Failed to send message to OpenAI Realtime API during session re-connection"
@@ -853,8 +862,9 @@ async def _recv_task() -> None:
                     aiohttp.WSMsgType.CLOSING,
                 ):
                     if closing:  # closing is expected, see _send_task
+                        print(f"✓ WebSocket closing normally")
                         return
-
+                    print(f"💀 WebSocket closed unexpectedly!")
                     # this will trigger a reconnection
                     raise APIConnectionError(message="OpenAI S2S connection closed unexpectedly")
 
@@ -972,6 +982,12 @@ async def _recv_task() -> None:
             await ws_conn.close()
 
     def _create_session_update_event(self) -> SessionUpdateEvent:
+        noise_reduction: realtime.realtime_audio_config_input.NoiseReduction | None = None
+        if self._realtime_model._opts.input_audio_noise_reduction:
+            noise_reduction = realtime.realtime_audio_config_input.NoiseReduction(
+                type=self._realtime_model._opts.input_audio_noise_reduction,
+            )
+
         audio_format = realtime.realtime_audio_formats.AudioPCM(rate=SAMPLE_RATE, type="audio/pcm")
         # they do not support both text and audio modalities, it'll respond in audio + transcript
         modality = "audio" if "audio" in self._realtime_model._opts.modalities else "text"
@@ -983,7 +999,7 @@ def _create_session_update_event(self) -> SessionUpdateEvent:
                 audio=RealtimeAudioConfig(
                     input=RealtimeAudioConfigInput(
                         format=audio_format,
-                        noise_reduction=self._realtime_model._opts.input_audio_noise_reduction,
+                        noise_reduction=noise_reduction,
                         transcription=self._realtime_model._opts.input_audio_transcription,
                         turn_detection=self._realtime_model._opts.turn_detection,
                     ),
@@ -1025,76 +1041,52 @@ def update_options(
         turn_detection: NotGivenOr[RealtimeAudioInputTurnDetection | None] = NOT_GIVEN,
         max_response_output_tokens: NotGivenOr[int | Literal["inf"] | None] = NOT_GIVEN,
         input_audio_transcription: NotGivenOr[AudioTranscription | None] = NOT_GIVEN,
-        input_audio_noise_reduction: NotGivenOr[
-            NoiseReductionType | NoiseReduction | InputAudioNoiseReduction | None
-        ] = NOT_GIVEN,
+        input_audio_noise_reduction: NotGivenOr[NoiseReductionType | None] = NOT_GIVEN,
         speed: NotGivenOr[float] = NOT_GIVEN,
         tracing: NotGivenOr[Tracing | None] = NOT_GIVEN,
     ) -> None:
-        session = RealtimeSessionCreateRequest(
-            type="realtime",
-        )
-
-        has_changes = False
+        kwargs: dict[str, Any] = {
+            "type": "realtime",
+        }
 
         if is_given(tool_choice):
             tool_choice = cast(Optional[llm.ToolChoice], tool_choice)
             self._realtime_model._opts.tool_choice = tool_choice
-            session.tool_choice = to_oai_tool_choice(tool_choice)
-            has_changes = True
-
-        if is_given(max_response_output_tokens):
-            self._realtime_model._opts.max_response_output_tokens = max_response_output_tokens  # type: ignore
-            session.max_output_tokens = max_response_output_tokens  # type: ignore
-            has_changes = True
-
-        if is_given(tracing):
-            self._realtime_model._opts.tracing = cast(Union[Tracing, None], tracing)
-            session.tracing = cast(Union[Tracing, None], tracing)  # type: ignore
-            has_changes = True
-
-        has_audio_config = False
-        audio_output = RealtimeAudioConfigOutput()
-        audio_input = RealtimeAudioConfigInput()
-        audio_config = RealtimeAudioConfig(
-            output=audio_output,
-            input=audio_input,
-        )
+            kwargs["tool_choice"] = to_oai_tool_choice(tool_choice)
 
         if is_given(voice):
             self._realtime_model._opts.voice = voice
-            audio_output.voice = voice
-            has_audio_config = True
+            kwargs["voice"] = voice
 
         if is_given(turn_detection):
             self._realtime_model._opts.turn_detection = turn_detection  # type: ignore
-            audio_input.turn_detection = turn_detection  # type: ignore
-            has_audio_config = True
+            kwargs["turn_detection"] = turn_detection
+
+        if is_given(max_response_output_tokens):
+            self._realtime_model._opts.max_response_output_tokens = max_response_output_tokens  # type: ignore
+            kwargs["max_response_output_tokens"] = max_response_output_tokens
 
         if is_given(input_audio_transcription):
             self._realtime_model._opts.input_audio_transcription = input_audio_transcription
-            audio_input.transcription = input_audio_transcription
-            has_audio_config = True
+            kwargs["input_audio_transcription"] = input_audio_transcription
 
         if is_given(input_audio_noise_reduction):
-            input_audio_noise_reduction = to_noise_reduction(input_audio_noise_reduction)  # type: ignore
-            self._realtime_model._opts.input_audio_noise_reduction = input_audio_noise_reduction
-            audio_input.noise_reduction = input_audio_noise_reduction
-            has_audio_config = True
+            self._realtime_model._opts.input_audio_noise_reduction = input_audio_noise_reduction  # type: ignore
+            kwargs["input_audio_noise_reduction"] = input_audio_noise_reduction
 
         if is_given(speed):
             self._realtime_model._opts.speed = speed
-            audio_output.speed = speed
-            has_audio_config = True
+            kwargs["speed"] = speed
 
-        if has_audio_config:
-            session.audio = audio_config
-            has_changes = True
+        if is_given(tracing):
+            self._realtime_model._opts.tracing = cast(Union[Tracing, None], tracing)
+            kwargs["tracing"] = cast(Union[Tracing, None], tracing)
 
-        if has_changes:
+        if kwargs:
             self.send_event(
                 SessionUpdateEvent(
                     type="session.update",
-                    session=session,
+                    session=RealtimeSessionCreateRequest.model_construct(**kwargs),
                     event_id=utils.shortuuid("options_update_"),
                 )
             )
@@ -1391,6 +1383,7 @@ def _handle_input_audio_buffer_speech_stopped(
         )
 
     def _handle_response_created(self, event: ResponseCreatedEvent) -> None:
+        print("response created")
         assert event.response.id is not None, "response.id is None"
 
         self._current_generation = _ResponseGeneration(
@@ -1405,7 +1398,6 @@ def _handle_response_created(self, event: ResponseCreatedEvent) -> None:
             message_stream=self._current_generation.message_ch,
             function_stream=self._current_generation.function_ch,
             user_initiated=False,
-            response_id=event.response.id,
         )
 
         if (
@@ -1437,6 +1429,7 @@ def _handle_response_output_item_added(self, event: ResponseOutputItemAddedEvent
             item_generation.audio_ch.close()
             item_generation.modalities.set_result(["text"])
 
+        print("sending current generation message channel")
         self._current_generation.message_ch.send_nowait(
             llm.MessageGeneration(
                 message_id=item_id,
@@ -1515,11 +1508,12 @@ def _handle_conversion_item_input_audio_transcription_failed(
     def _handle_response_text_delta(self, event: ResponseTextDeltaEvent) -> None:
         assert self._current_generation is not None, "current_generation is None"
         item_generation = self._current_generation.messages[event.item_id]
+
         if (
             item_generation.audio_ch.closed
             and self._current_generation._first_token_timestamp is None
-        ):
-            # only if audio is not available
+        ):
+            # only if audio is not available
             self._current_generation._first_token_timestamp = time.time()
 
         item_generation.text_ch.send_nowait(event.delta)
@@ -1542,11 +1536,14 @@ def _handle_response_audio_transcript_delta(self, event: dict[str, Any]) -> None
         item_generation.audio_transcript += delta
 
     def _handle_response_audio_delta(self, event: ResponseAudioDeltaEvent) -> None:
+        print("response audio delta")
+        # print("🔥 TEST: This should definitely show up!")
         assert self._current_generation is not None, "current_generation is None"
         item_generation = self._current_generation.messages[event.item_id]
 
         if self._current_generation._first_token_timestamp is None:
             self._current_generation._first_token_timestamp = time.time()
+
         if not item_generation.modalities.done():
             item_generation.modalities.set_result(["audio", "text"])
 
@@ -1601,6 +1598,7 @@ def _handle_response_output_item_done(self, event: ResponseOutputItemDoneEvent)
             item_generation.modalities.set_result(self._realtime_model._opts.modalities)
 
     def _handle_response_done(self, event: ResponseDoneEvent) -> None:
+        print("response done")
         if self._current_generation is None:
             return  # OpenAI has a race condition where we could receive response.done without any previous response.created (This happens generally during interruption)  # noqa: E501
 
@@ -1642,7 +1640,7 @@ def _handle_response_done(self, event: ResponseDoneEvent) -> None:
             ttft=ttft,
             duration=duration,
             cancelled=event.response.status == "cancelled",
-            label=self._realtime_model.label,
+            label=self._realtime_model._label,
             input_tokens=usage.get("input_tokens", 0),
             output_tokens=usage.get("output_tokens", 0),
             total_tokens=usage.get("total_tokens", 0),
@@ -1669,9 +1667,6 @@ def _handle_response_done(self, event: ResponseDoneEvent) -> None:
                 audio_tokens=usage.get("output_token_details", {}).get("audio_tokens", 0),
                 image_tokens=0,
             ),
-            metadata=Metadata(
-                model_name=self._realtime_model.model, model_provider=self._realtime_model.provider
-            ),
         )
         self.emit("metrics_collected", metrics)
         self._handle_response_done_but_not_complete(event)
@@ -1718,6 +1713,7 @@ def _handle_response_done_but_not_complete(self, event: ResponseDoneEvent) -> No
             logger.debug("Unknown response status: %s", event.response.status)
 
     def _handle_error(self, event: RealtimeErrorEvent) -> None:
+        print(f"💀 REALTIME ERROR: {event.error}")
        if event.error.message.startswith("Cancellation failed"):
            return
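
Note (outside the patch): a minimal, self-contained sketch of the timeout-based stream reading that the `_read_messages` change above relies on, so the reader loop can notice an interruption instead of blocking forever on the message stream. The function `read_stream_with_timeout` and its `is_interrupted` parameter are hypothetical names used only for illustration, not part of the LiveKit API.

import asyncio
from typing import AsyncIterator, Callable, TypeVar

T = TypeVar("T")


async def read_stream_with_timeout(
    stream: AsyncIterator[T],
    is_interrupted: Callable[[], bool],
    timeout: float = 3.0,
) -> list[T]:
    # Pull items manually with __anext__ so each wait can be bounded by a timeout.
    # Caveat: asyncio.wait_for cancels the pending __anext__ on timeout, which is not
    # safe for every async iterator (async generators may not survive the cancellation).
    items: list[T] = []
    it = stream.__aiter__()
    while True:
        try:
            item = await asyncio.wait_for(it.__anext__(), timeout=timeout)
        except asyncio.TimeoutError:
            # No item arrived in time: check for interruption, otherwise keep waiting.
            if is_interrupted():
                break
            continue
        except StopAsyncIteration:
            break
        items.append(item)
    return items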