Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vllm_omni/deploy/mimo_audio.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# MiMo Audio deploy: 2-stage thinker+talker → code2wav.
#
# Default mode is async-chunk streaming on a single GPU (both stages on
Expand All @@ -17,8 +18,8 @@ connectors:
connector_get_sleep_s: 0.001
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
codec_chunk_frames: 3
codec_left_context_frames: 3
codec_chunk_frames: 30
codec_left_context_frames: 40

stages:
- stage_id: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@

logger = logging.getLogger(__name__)

# Minimum safe values for codec streaming parameters. Mirrors the constants
# in stage_input_processors/mimo_audio.py — keep in sync.
_MIN_CODEC_CHUNK_FRAMES = 3
_MIN_CODEC_LEFT_CONTEXT_FRAMES = 40 # must cover vocoder_attn_window_size[0]
_DEFAULT_CODEC_CHUNK_FRAMES = 10
_DEFAULT_CODEC_LEFT_CONTEXT_FRAMES = 40


def flat_codec_group_element_count(group_size: int, audio_channels: int) -> int:
"""Flat token count for one MiMo talker codec group on the code2wav wire.
Expand Down Expand Up @@ -485,12 +492,37 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
if connector_cfg
else None
)
self._codec_chunk_frames = int(extra_cfg.get("codec_chunk_frames", 3)) if isinstance(extra_cfg, dict) else 3
if self._codec_chunk_frames <= 0:
raise ValueError(f"codec_chunk_frames must be positive, got {self._codec_chunk_frames}")
self._codec_left_context_frames = (
int(extra_cfg.get("codec_left_context_frames", 3)) if isinstance(extra_cfg, dict) else 3
raw_chunk = (
int(extra_cfg.get("codec_chunk_frames", _DEFAULT_CODEC_CHUNK_FRAMES))
if isinstance(extra_cfg, dict)
else _DEFAULT_CODEC_CHUNK_FRAMES
)
if raw_chunk < _MIN_CODEC_CHUNK_FRAMES:
logger.warning(
"codec_chunk_frames=%d is below minimum %d; falling back to %d.",
raw_chunk,
_MIN_CODEC_CHUNK_FRAMES,
_DEFAULT_CODEC_CHUNK_FRAMES,
)
raw_chunk = _DEFAULT_CODEC_CHUNK_FRAMES
self._codec_chunk_frames = raw_chunk

raw_left = (
int(extra_cfg.get("codec_left_context_frames", _DEFAULT_CODEC_LEFT_CONTEXT_FRAMES))
if isinstance(extra_cfg, dict)
else _DEFAULT_CODEC_LEFT_CONTEXT_FRAMES
)
if raw_left < _MIN_CODEC_LEFT_CONTEXT_FRAMES:
logger.warning(
"codec_left_context_frames=%d is below minimum %d (must cover vocoder attention "
"window %s); falling back to %d to prevent voice instability.",
raw_left,
_MIN_CODEC_LEFT_CONTEXT_FRAMES,
getattr(self.config, "vocoder_attn_window_size", [40, 10]),
_DEFAULT_CODEC_LEFT_CONTEXT_FRAMES,
)
raw_left = _DEFAULT_CODEC_LEFT_CONTEXT_FRAMES
self._codec_left_context_frames = raw_left

def load_weights(
self,
Expand Down
18 changes: 15 additions & 3 deletions vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,8 +542,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)

self.device = current_omni_platform.get_torch_device()
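# global_sampler MUST stay greedy (do_sample=False) so its token decision
# matches vLLM's external sampler (SamplingParams temperature=0.0). Both
# run argmax on the same logits, so they always agree on whether the next
# token is <|empty|> (audio step) or a real text token. Enabling
# do_sample=True here without also routing vLLM's sampled token back into
# this gate check would cause the two to diverge and corrupt KV-cache state.
# NOTE(review): this agreement only holds while the stage's SamplingParams are
# actually greedy; the stage-0 defaults in stage_configs/mimo_audio.yaml set
# temperature 0.6 — confirm which sampling params reach this model.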
self.global_sampler = MiMoSampler(do_sample=False, temperature=0.6, top_p=0.95)
self.local_sampler = MiMoSampler(do_sample=False, temperature=0.9, top_p=0.95)
# local_sampler drives audio-code generation inside local_forward. It is
# entirely internal and NOT subject to vLLM's SamplingParams, so stochastic
# sampling is safe here and is required to produce natural, varied speech.
# Setting do_sample=True also disables the CUDA-graph path (the use_cg gate in
# local_forward requires `do_sample` to be None or False), preventing
# MiMoLocalSamplerTensor from silently forcing argmax even when temperature > 0.
self.local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
self.removed_tokens = None

self.speech_vocab_sizes = config.parsed_speech_vocab_sizes()
Expand Down Expand Up @@ -806,7 +818,7 @@ def base_local_forward(
device=tokens_device,
)
if local_sampler is None:
local_sampler = MiMoSampler(do_sample=False, temperature=0.6, top_p=0.9)
local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)

past_key_values = DynamicCache()
for t in range(delay_iters):
Expand Down Expand Up @@ -852,7 +864,7 @@ def local_forward(
local_sampler: MiMoSampler | None = None,
):
if local_sampler is None:
local_sampler = MiMoSampler(do_sample=False, temperature=0.6, top_p=0.9)
local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)

b = int(local_embeds.shape[0])
use_cg = (local_sampler.do_sample is None or local_sampler.do_sample is False) and bool(
Expand Down
61 changes: 61 additions & 0 deletions vllm_omni/model_executor/stage_configs/mimo_audio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

# MiMo Audio deploy: 2-stage thinker+talker → code2wav.
#
# Default mode is async-chunk streaming on a single GPU (both stages on
# device 0) through SharedMemoryConnector. For the legacy 2-GPU sync
# pipeline (stage 1 on a second card), pass ``--no-async-chunk
# --stage-1-devices 1 --stage-1-max-model-len 18192
# --stage-1-max-num-batched-tokens 18192`` to the serve CLI.
async_chunk: true
dtype: bfloat16

connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 365536
codec_streaming: true
connector_get_sleep_s: 0.001
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
codec_chunk_frames: 30
codec_left_context_frames: 40

stages:
- stage_id: 0
max_num_seqs: 1
gpu_memory_utilization: 0.5
enforce_eager: true
trust_remote_code: true
enable_prefix_caching: false
max_num_batched_tokens: 8192
max_model_len: 8192
devices: "0"
output_connectors:
to_stage_1: connector_of_shared_memory
default_sampling_params:
temperature: 0.6
top_p: 0.95
top_k: 50
max_tokens: 18192
seed: 42
repetition_penalty: 1.1

- stage_id: 1
max_num_seqs: 1
gpu_memory_utilization: 0.4
enforce_eager: true
trust_remote_code: true
enable_prefix_caching: false
async_scheduling: false
max_num_batched_tokens: 8192
max_model_len: 8192
devices: "0"
input_connectors:
from_stage_0: connector_of_shared_memory
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 18192
seed: 42
41 changes: 34 additions & 7 deletions vllm_omni/model_executor/stage_input_processors/mimo_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
# example ``examples/offline_inference/mimo_audio/end2end.py``.
MAX_CODE2WAV_TOKENS = 18192

# Minimum safe values for codec streaming parameters.
# codec_left_context_frames must cover the vocoder attention window
# (vocoder_attn_window_size defaults to [40, 10]). Values below the minimum
# cause acoustic-state resets at chunk boundaries, producing voice instability
# (multiple speakers / timbre shifts in the output audio).
_MIN_CODEC_CHUNK_FRAMES = 3
_MIN_CODEC_LEFT_CONTEXT_FRAMES = 40
_DEFAULT_CODEC_CHUNK_FRAMES = 10
_DEFAULT_CODEC_LEFT_CONTEXT_FRAMES = 40


def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -80,6 +90,10 @@ def _flush_remaining_codes(

length = len(accumulated)
chunk_length = length % chunk_size
# When the accumulated length lands exactly on a chunk_size boundary
# (remainder == 0), the final chunk must still be flushed with full context so
# the vocoder has enough attention window — otherwise the tail audio is cut off
# and the voice becomes unstable. In that case, fall back to chunk_size as the
# context length.
context_length = chunk_length if chunk_length != 0 else chunk_size
end_index = min(length, left_context_size + context_length)

Expand Down Expand Up @@ -136,8 +150,26 @@ def llm2code2wav_async_chunk(
connector = getattr(transfer_manager, "connector", None)
raw_cfg = getattr(connector, "config", {}) or {}
cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
chunk_size = int(cfg.get("codec_chunk_frames", 3))
left_context_size = int(cfg.get("codec_left_context_frames", 3))
chunk_size = int(cfg.get("codec_chunk_frames", _DEFAULT_CODEC_CHUNK_FRAMES))
if chunk_size < _MIN_CODEC_CHUNK_FRAMES:
logger.warning(
"codec_chunk_frames=%d is below minimum %d; falling back to %d.",
chunk_size,
_MIN_CODEC_CHUNK_FRAMES,
_DEFAULT_CODEC_CHUNK_FRAMES,
)
chunk_size = _DEFAULT_CODEC_CHUNK_FRAMES

left_context_size = int(cfg.get("codec_left_context_frames", _DEFAULT_CODEC_LEFT_CONTEXT_FRAMES))
if left_context_size < _MIN_CODEC_LEFT_CONTEXT_FRAMES:
logger.warning(
"codec_left_context_frames=%d is below minimum %d (must cover vocoder attention window); "
"falling back to %d to prevent voice instability.",
left_context_size,
_MIN_CODEC_LEFT_CONTEXT_FRAMES,
_DEFAULT_CODEC_LEFT_CONTEXT_FRAMES,
)
left_context_size = _DEFAULT_CODEC_LEFT_CONTEXT_FRAMES

request_id = getattr(request, "external_req_id", None)

Expand All @@ -157,11 +189,6 @@ def llm2code2wav_async_chunk(
pad_vec = torch.tensor([TALKER_CODEC_PAD_TOKEN_ID] * 4, device=code_tensor.device, dtype=code_tensor.dtype)
code_list = prepend_and_flatten_colmajor(code_tensor, pad_vec).tolist()

if sum(code_list) == 0:
if is_finished:
return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size)
return None

if request_id is None:
return None

Expand Down
Loading