diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 49fd0c03a22..b4327a4c46b 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -361,6 +361,16 @@ async def create_chat_completion( _image_gen_height = None _image_gen_width = None + # Extract per-request voice_type from extra_body for audio output. + # Passed as additional_information so Qwen3-Omni's talker stage can + # select the speaker on a per-request basis. + _voice_type: str | None = None + if request.modalities and "audio" in request.modalities: + extra_body = getattr(request, "extra_body", None) + if not extra_body: + extra_body = getattr(request, "model_extra", None) or {} + _voice_type = extra_body.get("voice_type") or extra_body.get("voice") + # Schedule the request and get the result generator. generators: list[AsyncGenerator[RequestOutput, None]] = [] try: @@ -380,6 +390,14 @@ async def create_chat_completion( if hasattr(sp, "width") and _image_gen_width is not None: sp.width = _image_gen_width + # Inject voice_type into the engine prompt's additional_information + # so it flows through thinker → talker stage transition. + if _voice_type is not None and isinstance(engine_prompt, dict): + ai = engine_prompt.setdefault("additional_information", {}) + # Use a list wrapper so the serializer accepts the value + # (AdditionalInformationPayload only supports tensors and lists). 
+ ai["voice_type"] = [_voice_type] + self._log_inputs( request_id, engine_prompt, diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 21f0185aa3f..03fa2a0065a 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -322,9 +322,8 @@ def forward( if inputs_embeds is None and input_ids is not None: inputs_embeds = self.talker.embed_input_ids(input_ids) - # TODO(Peiqi): temporal hack here to support voice_type. - if not hasattr(self, "voice_type"): - self.voice_type = voice_type + # voice_type is now resolved per-request in talker_preprocess_prefill + # via info_dict["voice_type"]. No instance-level caching needed. # Run talker forward with torch.inference_mode(): @@ -678,8 +677,11 @@ def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor: def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): # Containers to return per-request updates (e.g., code_predictor_hidden_per_request) update_dict: dict[str, dict] = {} - # TODO(Peiqi): add voice_type support - voice_type = self.voice_type + # Per-request voice_type: read from additional_information passed by + # the stage input processor, falling back to the model default. + vt_raw = info_dict.get("voice_type", self.default_tts_text_spk_type) + # voice_type may arrive as a single-element list from serialization. 
+ voice_type = (vt_raw[0] if vt_raw else self.default_tts_text_spk_type) if isinstance(vt_raw, list) else vt_raw start_index = info_dict.get("num_processed_tokens", 0) end_index = start_index + input_embeds.shape[0] # Read thinker outputs for prefill diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 7cfc59f79c2..199799f2ef4 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -63,6 +63,37 @@ def _ensure_list(x): return list(x) +def _extract_voice_type(request_or_prompt) -> str | None: + """Extract voice_type from a request's additional_information. + + Supports both raw dict prompts and serialized OmniEngineCoreRequest. + Returns None if voice_type is not specified. + """ + # Raw dict prompt (non-async path) + if isinstance(request_or_prompt, dict): + ai = request_or_prompt.get("additional_information") + if isinstance(ai, dict): + vt = ai.get("voice_type") + # voice_type is stored as a list for serialization compatibility + if isinstance(vt, list) and len(vt) >= 1: + return str(vt[0]) + if isinstance(vt, str): + return vt + return None + + # OmniEngineCoreRequest (async_chunk path) + ai = getattr(request_or_prompt, "additional_information", None) + if ai is None: + return None + entries = getattr(ai, "entries", None) + if entries is None or "voice_type" not in entries: + return None + entry = entries["voice_type"] + if entry.list_data is not None and len(entry.list_data) >= 1: + return str(entry.list_data[0]) + return None + + def _validate_stage_inputs(stage_list, engine_input_source): if not engine_input_source: raise ValueError("engine_input_source cannot be empty") @@ -97,6 +128,8 @@ def thinker2talker_async_chunk( """ request_id = request.external_req_id + # Extract per-request voice_type from the thinker request's additional_information. 
+ voice_type = _extract_voice_type(request) chunk_id = transfer_manager.put_req_chunk[request_id] if chunk_id == 0: all_token_ids = request.all_token_ids # prefill + decode @@ -115,6 +148,8 @@ def thinker2talker_async_chunk( "tts_pad_embed": pooling_output.get("tts_pad_embed").detach().cpu(), "finished": torch.tensor(is_finished, dtype=torch.bool), } + if voice_type is not None: + talker_additional_info["voice_type"] = [voice_type] # List-wrapped, as in thinker2talker. if transfer_manager.request_payload.get(request_id) is None: if not is_finished: transfer_manager.request_payload[request_id] = talker_additional_info @@ -179,6 +214,9 @@ def thinker2talker( device = torch.device(current_platform.device_type) + # Extract per-request voice_type from the original prompt. + voice_type = _extract_voice_type(prompt) + # Process each thinker output for thinker_output in thinker_outputs: output = thinker_output.outputs[0] @@ -195,6 +233,10 @@ def thinker2talker( "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), } + if voice_type is not None: + # Wrap in list for AdditionalInformationPayload serialization + # compatibility (only tensors and lists are supported). + info["voice_type"] = [voice_type] prompt_len = _compute_talker_prompt_ids_length(info, device=device)