diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 79f0f4a8def..ad9f30a2a9f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -324,12 +324,18 @@ def forward( multimodal_outputs={"model_outputs": audios, "sr": srs}, ) - def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput | tuple, **kwargs: Any) -> OmniOutput: if isinstance(model_outputs, OmniOutput): return model_outputs + if isinstance(model_outputs, tuple) and len(model_outputs) == len(OmniOutput._fields): + return OmniOutput(*model_outputs) + if not (isinstance(model_outputs, tuple) and len(model_outputs) == 2): - raise TypeError(f"Qwen3TTSCode2Wav expected (audio_tensor, sr) outputs, got {type(model_outputs)}") + raise TypeError( + "Qwen3TTSCode2Wav expected OmniOutput, OmniOutput tuple, " + f"or (audio_tensor, sr) outputs, got {type(model_outputs)}" + ) audio_tensor, sr = model_outputs return OmniOutput(