Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/offline_inference/mimo_audio/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,13 @@ def main(args):
# Note: The audio files used in this example are available at: https://github.com/XiaomiMiMo/MiMo-Audio/tree/main/examples
if args.query_type == "tts_sft":
# python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft
if text is None:
text = "The weather is so nice today."
query_result = query_func(text=text, read_text_only=True)
elif args.query_type == "tts_sft_with_instruct":
# python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft_with_instruct --instruct "Speak happily in a child's voice"
if text is None:
text = "The weather is so nice today."
query_result = query_func(text=text, instruct=instruct, read_text_only=True)
elif args.query_type == "tts_sft_with_audio":
# python3 -u end2end.py --stage-configs-path ${config_file} --model ${MODEL_PATH} --query-type tts_sft_with_audio --audio_path "./spoken_dialogue_assistant_turn_1.wav"
Expand Down Expand Up @@ -355,7 +359,7 @@ def parse_args():
"--text",
"-t",
type=str,
default="",
default=None,
help="input text",
)
parser.add_argument(
Expand Down
24 changes: 24 additions & 0 deletions vllm_omni/deploy/mimo_audio.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,27 @@ stages:
top_k: -1
max_tokens: 18192
seed: 42

platforms:
xpu:
async_chunk: false
stages:
- stage_id: 0
gpu_memory_utilization: 0.5
enforce_eager: true
disable_hybrid_kv_cache_manager: true
enable_prefix_caching: false
skip_mm_profiling: true
max_num_batched_tokens: 8192
max_model_len: 8192
devices: "2"
- stage_id: 1
gpu_memory_utilization: 0.35
enforce_eager: true
disable_hybrid_kv_cache_manager: true
enable_prefix_caching: false
skip_mm_profiling: true
async_scheduling: false
max_num_batched_tokens: 8192
max_model_len: 8192
devices: "3"
18 changes: 16 additions & 2 deletions vllm_omni/model_executor/models/mimo_audio/mimo_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,18 @@ def _parse_audio_data(


class MiMoAudioLLMMultiModalProcessor(BaseMultiModalProcessor[MiMoAudioLLMProcessingInfo]):
def _hf_processor_applies_updates(
self,
prompt_text,
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
) -> bool:
# MiMoAudio's _call_hf_processor does NOT expand <|empty|> placeholders
# into audio feature tokens. We must let vllm apply the updates itself
# via _apply_prompt_updates (requires is_update_applied=False).
return False

def _call_hf_processor(
self,
prompt: str,
Expand Down Expand Up @@ -865,8 +877,10 @@ def generate_codes(
def generate_audio(self, code: torch.Tensor):
token2wav_dev = self._module_device(self.token2wav)
# Check if in CUDA graph capture phase
is_capturing = torch.cuda.is_current_stream_capturing()

if torch.cuda.is_available() and token2wav_dev.type == "cuda":
is_capturing = torch.cuda.is_current_stream_capturing()
else:
is_capturing = False
if isinstance(code, torch.Tensor):
if is_capturing:
# During CUDA graph capture, avoid device movement operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,10 @@ def forward(
)
per_left, per_chunk = self._mimo_codec_runtime_lists(len(request_ids_list), runtime_additional_information)

is_capturing = torch.cuda.is_current_stream_capturing()
if torch.cuda.is_available() and self.device.type == "cuda":
is_capturing = torch.cuda.is_current_stream_capturing()
else:
is_capturing = False
if is_capturing:
return OmniOutput(
text_hidden_states=None,
Expand Down Expand Up @@ -953,7 +956,10 @@ def _check_dummy_code_tensor(self, code_tensor: torch.Tensor) -> bool:

def _decode_waveform_from_codes(self, code_tensor: torch.Tensor) -> torch.Tensor:
# Check if in CUDA graph capture phase
is_capturing = torch.cuda.is_current_stream_capturing()
if torch.cuda.is_available() and self.device.type == "cuda":
is_capturing = torch.cuda.is_current_stream_capturing()
else:
is_capturing = False

# During CUDA graph capture, return a dummy tensor to avoid operations such as .cpu(), which are not allowed while capturing
if is_capturing:
Expand Down
17 changes: 10 additions & 7 deletions vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,11 +808,11 @@ def base_local_forward(
self,
local_embeds: torch.FloatTensor, # [1, 1, hidden_size]
tokens_dtype: torch.dtype = torch.int64,
tokens_device: torch.device = torch.device(
f"cuda:{torch.accelerator.current_device_index()}" if torch.cuda.is_available() else "cpu"
),
tokens_device: torch.device | None = None,
local_sampler: MiMoSampler | MiMoLocalSamplerTensor | None = None,
):
if tokens_device is None:
tokens_device = local_embeds.device
B = local_embeds.shape[0]
delay_iters = self.group_size + max(self.delay_pattern)

Expand Down Expand Up @@ -862,11 +862,11 @@ def local_forward(
self,
local_embeds: torch.FloatTensor, # [1, 1, hidden_size]
tokens_dtype: torch.dtype = torch.int64,
tokens_device: torch.device = torch.device(
f"cuda:{torch.accelerator.current_device_index()}" if torch.cuda.is_available() else "cpu"
),
tokens_device: torch.device | None = None,
local_sampler: MiMoSampler | None = None,
):
if tokens_device is None:
tokens_device = local_embeds.device
if local_sampler is None:
local_sampler = MiMoSampler(do_sample=False, temperature=0.9, top_p=0.95)

Expand Down Expand Up @@ -1049,7 +1049,10 @@ def forward(
else:
request_ids = [str(i) for i in range(len(query_start_loc[1:]))] if query_start_loc is not None else []
num_reqs = len(request_ids)
is_capturing = torch.cuda.is_current_stream_capturing()
if torch.cuda.is_available() and input_ids.device.type == "cuda":
is_capturing = torch.cuda.is_current_stream_capturing()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't other platforms support torch.xxx.is_current_stream_capturing()?

else:
is_capturing = False

merge_mm_embedding_info, has_merge_mm_embedding, kwargs = self._collect_merge_mm_embedding_info(
input_ids,
Expand Down
Loading