Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ async def create_chat_completion(
_image_gen_height = None
_image_gen_width = None

# Extract per-request voice_type from extra_body for audio output.
# Passed as additional_information so Qwen3-Omni's talker stage can
# select the speaker on a per-request basis.
_voice_type: str | None = None
if request.modalities and "audio" in request.modalities:
extra_body = getattr(request, "extra_body", None)
if not extra_body:
extra_body = getattr(request, "model_extra", None) or {}
_voice_type = extra_body.get("voice_type") or extra_body.get("voice")

# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
Expand All @@ -380,6 +390,14 @@ async def create_chat_completion(
if hasattr(sp, "width") and _image_gen_width is not None:
sp.width = _image_gen_width

# Inject voice_type into the engine prompt's additional_information
# so it flows through the thinker → talker stage transition.
if _voice_type is not None and isinstance(engine_prompt, dict):
ai = engine_prompt.setdefault("additional_information", {})
# Use a list wrapper so the serializer accepts the value
# (AdditionalInformationPayload only supports tensors and lists).
ai["voice_type"] = [_voice_type]

self._log_inputs(
request_id,
engine_prompt,
Expand Down
12 changes: 7 additions & 5 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,8 @@ def forward(
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.talker.embed_input_ids(input_ids)

# TODO(Peiqi): temporary hack here to support voice_type.
if not hasattr(self, "voice_type"):
self.voice_type = voice_type
# voice_type is now resolved per-request in talker_preprocess_prefill
# via info_dict["voice_type"]. No instance-level caching needed.

# Run talker forward
with torch.inference_mode():
Expand Down Expand Up @@ -678,8 +677,11 @@ def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor:
def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
# Containers to return per-request updates (e.g., code_predictor_hidden_per_request)
update_dict: dict[str, dict] = {}
# TODO(Peiqi): add voice_type support
voice_type = self.voice_type
# Per-request voice_type: read from additional_information passed by
# the stage input processor, falling back to the model default.
vt_raw = info_dict.get("voice_type", self.default_tts_text_spk_type)
# voice_type may arrive as a single-element list from serialization.
voice_type = vt_raw[0] if isinstance(vt_raw, list) and vt_raw else vt_raw
start_index = info_dict.get("num_processed_tokens", 0)
end_index = start_index + input_embeds.shape[0]
# Read thinker outputs for prefill
Expand Down
42 changes: 42 additions & 0 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,37 @@ def _ensure_list(x):
return list(x)


def _extract_voice_type(request_or_prompt) -> str | None:
"""Extract voice_type from a request's additional_information.

Supports both raw dict prompts and serialized OmniEngineCoreRequest.
Returns None if voice_type is not specified.
"""
# Raw dict prompt (non-async path)
if isinstance(request_or_prompt, dict):
ai = request_or_prompt.get("additional_information")
if isinstance(ai, dict):
vt = ai.get("voice_type")
# voice_type is stored as a list for serialization compatibility
if isinstance(vt, list) and len(vt) >= 1:
return str(vt[0])
if isinstance(vt, str):
return vt
return None

# OmniEngineCoreRequest (async_chunk path)
ai = getattr(request_or_prompt, "additional_information", None)
if ai is None:
return None
entries = getattr(ai, "entries", None)
if entries is None or "voice_type" not in entries:
return None
entry = entries["voice_type"]
if entry.list_data is not None and len(entry.list_data) >= 1:
return str(entry.list_data[0])
return None


def _validate_stage_inputs(stage_list, engine_input_source):
if not engine_input_source:
raise ValueError("engine_input_source cannot be empty")
Expand Down Expand Up @@ -97,6 +128,8 @@ def thinker2talker_async_chunk(
"""

request_id = request.external_req_id
# Extract per-request voice_type from the thinker request's additional_information.
voice_type = _extract_voice_type(request)
chunk_id = transfer_manager.put_req_chunk[request_id]
if chunk_id == 0:
all_token_ids = request.all_token_ids # prefill + decode
Expand All @@ -115,6 +148,8 @@ def thinker2talker_async_chunk(
"tts_pad_embed": pooling_output.get("tts_pad_embed").detach().cpu(),
"finished": torch.tensor(is_finished, dtype=torch.bool),
}
if voice_type is not None:
talker_additional_info["voice_type"] = voice_type
if transfer_manager.request_payload.get(request_id) is None:
if not is_finished:
transfer_manager.request_payload[request_id] = talker_additional_info
Expand Down Expand Up @@ -179,6 +214,9 @@ def thinker2talker(

device = torch.device(current_platform.device_type)

# Extract per-request voice_type from the original prompt.
voice_type = _extract_voice_type(prompt)

# Process each thinker output
for thinker_output in thinker_outputs:
output = thinker_output.outputs[0]
Expand All @@ -195,6 +233,10 @@ def thinker2talker(
"tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float),
"tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float),
}
if voice_type is not None:
# Wrap in list for AdditionalInformationPayload serialization
# compatibility (only tensors and lists are supported).
info["voice_type"] = [voice_type]

prompt_len = _compute_talker_prompt_ids_length(info, device=device)

Expand Down
Loading