diff --git a/examples/offline_inference/mimo_audio/end2end.py b/examples/offline_inference/mimo_audio/end2end.py index d3728dcc55b..ecdbc5cfb52 100644 --- a/examples/offline_inference/mimo_audio/end2end.py +++ b/examples/offline_inference/mimo_audio/end2end.py @@ -221,9 +221,13 @@ def main(args): # Notice: The audio files used in this example are available at: https://github.com/XiaomiMiMo/MiMo-Audio/tree/main/examples if args.query_type == "tts_sft": # python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft + if text is None: + text = "The weather is so nice today." query_result = query_func(text=text, read_text_only=True) elif args.query_type == "tts_sft_with_instruct": # python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft_with_instruct --instruct "Speak happily in a child's voice" + if text is None: + text = "The weather is so nice today." query_result = query_func(text=text, instruct=instruct, read_text_only=True) elif args.query_type == "tts_sft_with_audio": # python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft_with_audio --audio_path "./spoken_dialogue_assistant_turn_1.wav" @@ -355,7 +359,7 @@ def parse_args(): "--text", "-t", type=str, - default="", + default=None, help="input text", ) parser.add_argument( diff --git a/vllm_omni/deploy/mimo_audio.yaml b/vllm_omni/deploy/mimo_audio.yaml index 6ebe5920957..a4d6fd56f6c 100644 --- a/vllm_omni/deploy/mimo_audio.yaml +++ b/vllm_omni/deploy/mimo_audio.yaml @@ -59,3 +59,27 @@ stages: top_k: -1 max_tokens: 18192 seed: 42 + +platforms: + xpu: + async_chunk: false + stages: + - stage_id: 0 + gpu_memory_utilization: 0.5 + enforce_eager: true + disable_hybrid_kv_cache_manager: true + enable_prefix_caching: false + skip_mm_profiling: true + max_num_batched_tokens: 8192 + max_model_len: 8192 + devices: "2" + - stage_id: 1 + gpu_memory_utilization: 0.35 + enforce_eager: true + disable_hybrid_kv_cache_manager: true + enable_prefix_caching: false + skip_mm_profiling: true + async_scheduling: false + max_num_batched_tokens: 8192 + max_model_len: 8192 + devices: "3" diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py index ecf094c879c..9f9369c9cd8 100644 --- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py +++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py @@ -349,6 +349,18 @@ def _parse_audio_data( class MiMoAudioLLMMultiModalProcessor(BaseMultiModalProcessor[MiMoAudioLLMProcessingInfo]): + def _hf_processor_applies_updates( + self, + prompt_text, + mm_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + ) -> bool: + # MiMoAudio's _call_hf_processor does NOT expand <|empty|> placeholders + # into audio feature tokens. We must let vllm apply the updates itself + # via _apply_prompt_updates (requires is_update_applied=False). + return False + def _call_hf_processor( self, prompt: str, @@ -865,8 +877,10 @@ def generate_codes( def generate_audio(self, code: torch.Tensor): token2wav_dev = self._module_device(self.token2wav) # Check if in CUDA graph capture phase - is_capturing = torch.cuda.is_current_stream_capturing() - + if torch.cuda.is_available() and token2wav_dev.type == "cuda": + is_capturing = torch.cuda.is_current_stream_capturing() + else: + is_capturing = False if isinstance(code, torch.Tensor): if is_capturing: # During CUDA graph capture, avoid device movement operations diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py index bd178afe667..4fc88818f46 100644 --- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py +++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py @@ -645,7 +645,10 @@ def forward( ) per_left, per_chunk = self._mimo_codec_runtime_lists(len(request_ids_list), runtime_additional_information) - is_capturing = torch.cuda.is_current_stream_capturing() + if torch.cuda.is_available() and self.device.type == "cuda": + is_capturing = torch.cuda.is_current_stream_capturing() + else: + is_capturing = False if is_capturing: return OmniOutput( text_hidden_states=None, @@ -953,7 +956,10 @@ def _check_dummy_code_tensor(self, code_tensor: torch.Tensor) -> bool: def _decode_waveform_from_codes(self, code_tensor: torch.Tensor) -> torch.Tensor: # Check if in CUDA graph capture phase - is_capturing = torch.cuda.is_current_stream_capturing() + if torch.cuda.is_available() and self.device.type == "cuda": + is_capturing = torch.cuda.is_current_stream_capturing() + else: + is_capturing = False # During CUDA graph capture, return dummy tensor to avoid operations like .cpu() which are not allowed if is_capturing: diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py index a640ade3aaf..d51c2d486bc 100644 --- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py +++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py @@ -808,11 +808,11 @@ def base_local_forward( self, local_embeds: torch.FloatTensor, # [1, 1, hidden_size] tokens_dtype: torch.dtype = torch.int64, - tokens_device: torch.device = torch.device( - f"cuda:{torch.accelerator.current_device_index()}" if torch.cuda.is_available() else "cpu" - ), + tokens_device: torch.device | None = None, local_sampler: MiMoSampler | MiMoLocalSamplerTensor | None = None, ): + if tokens_device is None: + tokens_device = local_embeds.device B = local_embeds.shape[0] delay_iters = self.group_size + max(self.delay_pattern) @@ -862,11 +862,11 @@ def local_forward( self, local_embeds: torch.FloatTensor, # [1, 1, hidden_size] tokens_dtype: torch.dtype = torch.int64, - tokens_device: torch.device = torch.device( - f"cuda:{torch.accelerator.current_device_index()}" if torch.cuda.is_available() else "cpu" - ), + tokens_device: torch.device | None = None, local_sampler: MiMoSampler | None = None, ): + if tokens_device is None: + tokens_device = local_embeds.device if local_sampler is None: local_sampler = MiMoSampler(do_sample=False, temperature=0.9, top_p=0.95) @@ -1049,7 +1049,10 @@ def forward( else: request_ids = [str(i) for i in range(len(query_start_loc[1:]))] if query_start_loc is not None else [] num_reqs = len(request_ids) - is_capturing = torch.cuda.is_current_stream_capturing() + if torch.cuda.is_available() and input_ids.device.type == "cuda": + is_capturing = torch.cuda.is_current_stream_capturing() + else: + is_capturing = False merge_mm_embedding_info, has_merge_mm_embedding, kwargs = self._collect_merge_mm_embedding_info( input_ids,