diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
index 664d2a2957c..41832267202 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
@@ -38,6 +38,18 @@
 logger = init_logger(__name__)
 
 
+_TASK_TYPE_CANONICAL: dict[str, str] = {
+    "customvoice": "CustomVoice",
+    "voicedesign": "VoiceDesign",
+    "base": "Base",
+}
+
+
+def _normalize_task_type(raw: str) -> str:
+    """Normalize task type string to its canonical PascalCase form."""
+    return _TASK_TYPE_CANONICAL.get(raw.lower(), raw)
+
+
 AudioLike = (
     str  # wav path, URL, base64
     | np.ndarray  # waveform (requires sr)
@@ -81,7 +93,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             torch_dtype=torch.bfloat16,
             **attn_kwargs,
         )
-        self.task_type = model_path.split("-")[-1].split("/")[0]
+        self.task_type = _normalize_task_type(model_path.split("-")[-1].split("/")[0])
         # Mark that this model produces multimodal outputs
         self.have_multimodal_outputs = True
 
@@ -116,8 +128,8 @@ def forward(
         if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
             runtime_additional_information = runtime_additional_information[0]
         text = runtime_additional_information.pop("text", [""])[0]
-        # Extract task_type from kwargs, default to "instruct"
-        task_type = runtime_additional_information.pop("task_type", [self.task_type])[0]
+        # Extract task_type from kwargs, default to self.task_type
+        task_type = _normalize_task_type(runtime_additional_information.pop("task_type", [self.task_type])[0])
         speaker = runtime_additional_information.pop("speaker", ["uncle_fu"])[0]
         language = runtime_additional_information.pop("language", ["Auto"])[0]
         instruct = runtime_additional_information.pop("instruct", [""])[0]