-
Notifications
You must be signed in to change notification settings - Fork 1k
[Feature] add session based audio streaming input #2208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7ea9d41
746871f
c6396fc
d0e808d
572bd3e
5a89944
1777e57
2871922
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """ | ||
| This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform | ||
| audio transcription by uploading an audio file. | ||
|
|
||
| Before running this script, you must start the vLLM-Omni server with a realtime-capable | ||
| model, for example: | ||
|
|
||
| vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni | ||
|
|
||
| Requirements: | ||
| - vllm with audio support | ||
| - websockets | ||
| - librosa | ||
| - numpy | ||
|
|
||
| The script: | ||
| 1. Connects to the Realtime WebSocket endpoint | ||
| 2. Converts an audio file to PCM16 @ 16kHz | ||
| 3. Sends audio chunks to the server | ||
| 4. Receives and prints transcription as it streams | ||
| """ | ||
|
|
||
| import argparse | ||
| import asyncio | ||
| import base64 | ||
| import json | ||
|
|
||
| import librosa | ||
| import numpy as np | ||
| import websockets | ||
| from vllm.assets.audio import AudioAsset | ||
|
|
||
|
|
||
def audio_to_pcm16_base64(audio_path: str) -> str:
    """Load an audio file and return it as base64-encoded PCM16 @ 16 kHz mono.

    Args:
        audio_path: Path to an audio file that ``librosa.load`` can read.

    Returns:
        Base64 text (as ``str``) wrapping the raw signed 16-bit samples.
    """
    # Decode, downmix to mono and resample to 16 kHz in one step.
    audio, _ = librosa.load(audio_path, sr=16000, mono=True)
    # Samples can land slightly outside [-1.0, 1.0] (e.g. resampling overshoot).
    # Casting such values to int16 is not well defined and typically wraps
    # around, so clamp to the representable range before scaling.
    clipped = np.clip(audio, -1.0, 1.0)
    pcm16 = (clipped * 32767).astype(np.int16)
    # Base64 so the samples can travel inside a JSON text message.
    return base64.b64encode(pcm16.tobytes()).decode("utf-8")
|
|
||
|
|
||
async def realtime_transcribe(audio_path: str, host: str, port: int, model: str):
    """Stream an audio file to the Realtime WebSocket API and print the transcript.

    Args:
        audio_path: Path of the audio file to transcribe.
        host: Server host name or address.
        port: Server port.
        model: Model name sent in the ``session.update`` message.

    Returns:
        None. Progress, transcription text and errors are printed to stdout.
    """
    uri = f"ws://{host}:{port}/v1/realtime"

    async with websockets.connect(uri) as ws:
        # The first server message is expected to be session.created.
        response = json.loads(await ws.recv())
        if response.get("type") != "session.created":
            print(f"Unexpected response: {response}")
            return
        print(f"Session created: {response['id']}")

        # Validate model
        await ws.send(json.dumps({"type": "session.update", "model": model}))

        # Signal ready to start
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        # Convert audio file to base64 PCM16
        print(f"Loading audio from: {audio_path}")
        audio_base64 = audio_to_pcm16_base64(audio_path)

        # Send 4 KiB of raw PCM per message; base64 inflates each payload by
        # about 4/3. 4096 is even, so no 16-bit sample is split across chunks.
        chunk_size = 4096
        audio_bytes = base64.b64decode(audio_base64)
        total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size

        print(f"Sending {total_chunks} audio chunks...")
        for start in range(0, len(audio_bytes), chunk_size):
            chunk = audio_bytes[start : start + chunk_size]
            await ws.send(
                json.dumps(
                    {
                        "type": "input_audio_buffer.append",
                        "audio": base64.b64encode(chunk).decode("utf-8"),
                    }
                )
            )

        # Signal all audio is sent
        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
        print("Audio sent. Waiting for transcription...\n")

        # Receive transcription
        print("Transcription: ", end="", flush=True)
        while True:
            try:
                raw_message = await ws.recv()
            except websockets.ConnectionClosed as exc:
                # The server went away before transcription.done; report it
                # instead of letting the script die with a traceback.
                print(f"\nConnection closed before transcription finished: {exc}")
                break

            response = json.loads(raw_message)
            # .get() so a message without a "type" field is skipped like any
            # other unrecognized message rather than raising KeyError.
            message_type = response.get("type")
            if message_type == "transcription.delta":
                print(response["delta"], end="", flush=True)
            elif message_type == "transcription.done":
                print(f"\n\nFinal transcription: {response['text']}")
                if response.get("usage"):
                    print(f"Usage: {response['usage']}")
                break
            elif message_type == "error":
                print(f"\nError: {response['error']}")
                break
|
|
||
|
|
||
def main(args):
    """Resolve the audio file to use and run the realtime transcription client.

    Args:
        args: Parsed command-line namespace providing ``audio_path``, ``host``,
            ``port`` and ``model``.
    """
    audio_path = args.audio_path
    if not audio_path:
        # Nothing supplied on the command line: fall back to the bundled asset.
        audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
        print(f"No audio path provided, using default: {audio_path}")

    transcription = realtime_transcribe(audio_path, args.host, args.port, args.model)
    asyncio.run(transcription)
|
|
||
|
|
||
if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, register them
    # in order, then hand the parsed namespace to main().
    option_specs = (
        (
            "--model",
            {
                "type": str,
                "default": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
                "help": "Model that is served and should be pinged.",
            },
        ),
        (
            "--audio_path",
            {
                "type": str,
                "default": None,
                "help": "Path to the audio file to transcribe.",
            },
        ),
        (
            "--host",
            {
                "type": str,
                "default": "localhost",
                "help": "vLLM-Omni server host (default: localhost)",
            },
        ),
        (
            "--port",
            {
                "type": int,
                "default": 8000,
                "help": "vLLM-Omni server port (default: 8000)",
            },
        ),
    )

    parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -635,9 +635,13 @@ def _build_add_request_message( | |
| self, | ||
| request_id: str, | ||
| prompt: EngineCoreRequest | PromptType, | ||
| prompt_text: str | None = None, | ||
| sampling_params_list: Sequence[Any] | None = None, | ||
| final_stage_id: int = 0, | ||
| arrival_time: float | None = None, | ||
| *, | ||
| resumable: bool = False, | ||
| message_type: str = "add_request", | ||
| ) -> dict[str, Any]: | ||
| """Build an add_request message after stage-0 preprocessing.""" | ||
| effective_sampling_params_list = ( | ||
|
|
@@ -669,6 +673,7 @@ def _build_add_request_message( | |
| params=params, | ||
| supported_tasks=self.supported_tasks, | ||
| arrival_time=arrival_time, | ||
| resumable=resumable, | ||
| ) | ||
| # TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have | ||
| # additional_information field in the prompt. | ||
|
|
@@ -683,17 +688,18 @@ def _build_add_request_message( | |
| request.external_req_id = request_id | ||
|
|
||
| # Register with stage 0's output processor. | ||
| output_prompt_text = prompt_text | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please confirm whether this is necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this aligns with vLLM: the prompt in RequestState within output_processor.py should be of type str; otherwise, |
||
| self.output_processors[0].add_request( | ||
| request=request, | ||
| prompt=prompt, | ||
| prompt=output_prompt_text, | ||
| parent_req=None, | ||
| request_index=0, | ||
| queue=None, | ||
| ) | ||
| prompt = request | ||
|
|
||
| return { | ||
| "type": "add_request", | ||
| "type": message_type, | ||
| "request_id": request_id, | ||
| "prompt": prompt, | ||
| "original_prompt": original_prompt, | ||
|
|
@@ -949,9 +955,12 @@ def add_request( | |
| self, | ||
| request_id: str, | ||
| prompt: EngineCoreRequest | PromptType, | ||
| prompt_text: str | None = None, | ||
| sampling_params_list: Sequence[Any] | None = None, | ||
| final_stage_id: int = 0, | ||
| arrival_time: float | None = None, | ||
| *, | ||
| resumable: bool = False, | ||
| ) -> None: | ||
| """Process stage-0 input locally, then send to the Orchestrator. | ||
|
|
||
|
|
@@ -963,9 +972,11 @@ def add_request( | |
| msg = self._build_add_request_message( | ||
| request_id=request_id, | ||
| prompt=prompt, | ||
| prompt_text=prompt_text, | ||
| sampling_params_list=sampling_params_list, | ||
| final_stage_id=final_stage_id, | ||
| arrival_time=arrival_time, | ||
| resumable=resumable, | ||
| ) | ||
| if self.request_queue is None: | ||
| raise RuntimeError("request_queue is not initialized") | ||
|
|
@@ -984,17 +995,70 @@ async def add_request_async( | |
| self, | ||
| request_id: str, | ||
| prompt: EngineCoreRequest | PromptType, | ||
| prompt_text: str | None = None, | ||
| sampling_params_list: Sequence[Any] | None = None, | ||
| final_stage_id: int = 0, | ||
| arrival_time: float | None = None, | ||
| *, | ||
| resumable: bool = False, | ||
| ) -> None: | ||
| """Async add_request API.""" | ||
| self.add_request( | ||
| request_id=request_id, | ||
| prompt=prompt, | ||
| prompt_text=prompt_text, | ||
| sampling_params_list=sampling_params_list, | ||
| final_stage_id=final_stage_id, | ||
| arrival_time=arrival_time, | ||
| resumable=resumable, | ||
| ) | ||
|
|
||
def add_streaming_update(
    self,
    request_id: str,
    prompt: EngineCoreRequest | PromptType,
    prompt_text: str | None = None,
    sampling_params_list: Sequence[Any] | None = None,
    final_stage_id: int = 0,
    arrival_time: float | None = None,
    *,
    resumable: bool = True,
) -> None:
    """Queue a ``streaming_update`` message for ``request_id``.

    The message is produced by ``_build_add_request_message`` with
    ``message_type="streaming_update"``; every other argument is forwarded
    unchanged. It is then handed to ``self.request_queue.sync_q`` via
    ``put_nowait``.

    Raises:
        RuntimeError: If ``self.request_queue`` is ``None``.
    """
    # The message is built before the queue is checked, so whatever the
    # builder does still happens even when the queue turns out to be missing.
    build_kwargs = {
        "request_id": request_id,
        "prompt": prompt,
        "prompt_text": prompt_text,
        "sampling_params_list": sampling_params_list,
        "final_stage_id": final_stage_id,
        "arrival_time": arrival_time,
        "resumable": resumable,
        "message_type": "streaming_update",
    }
    update_msg = self._build_add_request_message(**build_kwargs)

    if self.request_queue is None:
        raise RuntimeError("request_queue is not initialized")
    self.request_queue.sync_q.put_nowait(update_msg)
|
|
||
async def add_streaming_update_async(
    self,
    request_id: str,
    prompt: EngineCoreRequest | PromptType,
    prompt_text: str | None = None,
    sampling_params_list: Sequence[Any] | None = None,
    final_stage_id: int = 0,
    arrival_time: float | None = None,
    *,
    resumable: bool = True,
) -> None:
    """Coroutine form of ``add_streaming_update()``.

    Forwards every argument by keyword to the synchronous method; nothing is
    awaited inside this coroutine.
    """
    forwarded = {
        "request_id": request_id,
        "prompt": prompt,
        "prompt_text": prompt_text,
        "sampling_params_list": sampling_params_list,
        "final_stage_id": final_stage_id,
        "arrival_time": arrival_time,
        "resumable": resumable,
    }
    self.add_streaming_update(**forwarded)
|
|
||
| def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the README.md