From 46f289162e4c2d7de32c63ca8d2b3b94c78483e9 Mon Sep 17 00:00:00 2001 From: Rein Yang Date: Thu, 19 Mar 2026 07:14:59 +0000 Subject: [PATCH] revert 1758 which introduce qwen3-omni percision problem Signed-off-by: Rein Yang --- .../models/qwen3_omni/qwen3_omni.py | 50 +- .../qwen3_omni_moe_code_predictor_mtp.py | 554 ++++++++++-------- .../qwen3_omni/qwen3_omni_moe_talker.py | 115 ++-- vllm_omni/worker/gpu_model_runner.py | 9 +- 4 files changed, 428 insertions(+), 300 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 6dcd278acae..ec928195ab7 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -159,12 +159,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # suppress tokens by setting their probability to ~1e-9 (finite very small) self.suppressed_tokens = self._get_talker_suppressed_tokens() self.requires_raw_input_tokens = True - # Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips - self.gpu_resident_buffer_keys: set[str] = { - "last_talker_hidden", - "trailing_text_hidden", - "tts_pad_embed_projected", - } elif self.model_stage == "code2wav": self.enable_update_additional_information = True @@ -229,16 +223,14 @@ def embed_multimodal(self, **kwargs): # ==================== Forward Pass ==================== def _get_talker_suppressed_tokens(self): - """Return a boolean mask on GPU for suppressed token positions.""" - vocab_size = self.config.talker_config.text_config.vocab_size - mask = torch.zeros(vocab_size, dtype=torch.bool) - start = vocab_size - 1024 - eos_id = self.config.talker_config.codec_eos_token_id - for i in range(start, vocab_size): - if i != eos_id: - mask[i] = True - # Will be moved to the correct device on first use - return mask + return [ + i + for i in range( + self.config.talker_config.text_config.vocab_size - 1024, + self.config.talker_config.text_config.vocab_size, + ) + if i != self.config.talker_config.codec_eos_token_id + ] def get_mrope_input_positions( self, @@ -586,7 +578,7 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object): Postprocess the talker hidden states. """ update_dict = {} - update_dict["last_talker_hidden"] = hidden_states[-1, :].detach() + update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous() return update_dict def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): @@ -640,9 +632,9 @@ def talker_mtp( if inputs_embeds.shape[-1] == 2048: inputs_embeds = self.text_projection(inputs_embeds) code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( - input_ids, inputs_embeds, last_talker_hidden=last_talker_hidden + input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden ) - inputs_embeds = summed_embeddings + inputs_embeds = summed_embeddings.clone() inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size) return inputs_embeds, code_predictor_codes.squeeze(-1) @@ -766,7 +758,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch # compatible with old shape [1,S,D] rem_tail = trailing_text_hidden.squeeze(0) if rem_tail.shape[0] > 0: - update_dict["trailing_text_hidden"] = rem_tail.detach() + update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous() # Also persist projected tts_pad for decode fallback if needed if isinstance(tts_pad_thinker, torch.Tensor): pad_in = tts_pad_thinker @@ -775,7 +767,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch if pad_in.ndim == 1: pad_in = pad_in.view(1, 1, -1) pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker))) - update_dict["tts_pad_embed_projected"] = pad_proj.detach() + update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous() except Exception: pass self._talker_cache_thinker_decode_embeds(info_dict, update_dict) @@ -934,7 +926,7 @@ def talker_preprocess_decode( if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0: use_vec = q_tail[0:1, :] new_q_tail = ( - q_tail[1:, :].detach() + q_tail[1:, :].detach().to("cpu").contiguous() if q_tail.shape[0] > 1 else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) ) @@ -1130,10 +1122,16 @@ def compute_logits( # implemented by assigning their logits to log(1e-9). if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor): - # Move mask to device once (lazy), then reuse every step - if self.suppressed_tokens.device != logits.device: - self.suppressed_tokens = self.suppressed_tokens.to(logits.device) - logits.masked_fill_(self.suppressed_tokens.unsqueeze(0), -1e9) + try: + logits_cpu = logits.cpu() + logits_cpu[:, self.suppressed_tokens] = -1e9 + logits = logits_cpu.to(logits.device) + except Exception as e: + print(f"Error in logits suppression: {e}") + print(f"logits.shape: {logits.shape}") + print(f"self.suppressed_tokens: {self.suppressed_tokens}") + raise e + logits[:, self.suppressed_tokens] = -1e9 return logits def sample( diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py index 2b823364f65..fc7402890ab 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -1,16 +1,27 @@ -"""Qwen3-Omni Code Predictor -- optimized re-prefill, no KV cache. +"""Qwen3-Omni Code Predictor with MTP (Multi-Token Prediction) support. -* SDPA attention (F.scaled_dot_product_attention) -- no HF backend fallback -* Persistent pre-allocated buffers (_proj_buf, _pos_ids) -- zero per-call alloc -* Inline top-k sampling -- no LogitsProcessorList / custom-op overhead -* torch.compile on inner transformer by default -* No @support_torch_compile / static_forward_context / namedtuple +This module implements the code predictor component for Qwen3-Omni talker models. + +The code predictor generates residual RVQ (Residual Vector Quantization) codes +autoregressively, predicting layers 1 to N based on layer-0 codes from the talker. """ +from collections import namedtuple +from typing import Any + import torch import torch.nn as nn import torch.nn.functional as F -from vllm.config import VllmConfig +from transformers import Cache, PretrainedConfig +from transformers.generation.logits_process import ( + LogitsProcessorList, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -21,53 +32,47 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding - -from vllm_omni.platforms import current_omni_platform +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) - -# =================================================================== -# Standalone Attention (SDPA, no KV cache, no HF backend fallback) -# =================================================================== +# ============================================================================ +# Code Predictor Attention Layer +# ============================================================================ class Qwen3OmniCodePredictorAttention(nn.Module): - """Multi-head self-attention for code predictor. - - Uses ``F.scaled_dot_product_attention`` directly. No KV cache -- the code - predictor always re-prefills the full (short) sequence each AR step. - - Input : [B, seq_len, hidden_size] - Output: [B, seq_len, hidden_size] - """ + """Multi-head self-attention for code predictor with vLLM optimization.""" def __init__( self, config, + layer_idx: int, + vllm_config: VllmConfig = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() - cp_cfg = config.code_predictor_config - self.num_heads = cp_cfg.num_attention_heads - self.num_kv_heads = cp_cfg.num_key_value_heads + + self.num_heads = config.code_predictor_config.num_attention_heads + self.num_key_value_heads = config.code_predictor_config.num_key_value_heads self.head_dim = getattr( - cp_cfg, + config.code_predictor_config, "head_dim", - cp_cfg.hidden_size // cp_cfg.num_attention_heads, + config.code_predictor_config.hidden_size // config.code_predictor_config.num_attention_heads, ) - self.hidden_size = cp_cfg.hidden_size - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self._use_gqa = self.num_kv_heads != self.num_heads + self.hidden_size = config.code_predictor_config.hidden_size + + if self.num_heads % self.num_key_value_heads != 0: + raise ValueError("num_attention_heads must be divisible by num_key_value_heads") + + self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.qkv_proj = QKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, total_num_heads=self.num_heads, - total_num_kv_heads=self.num_kv_heads, + total_num_kv_heads=self.num_key_value_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -83,55 +88,120 @@ def __init__( ) self.rotary_emb = get_rope( self.head_dim, - max_position=cp_cfg.max_position_embeddings, + max_position=config.code_predictor_config.max_position_embeddings, rope_parameters=None, dual_chunk_attention_config=None, ) - self.q_norm = RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + + # Query/Key normalization + self.q_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) + self.is_causal = True + self.config = config + + self.attention_backends = ["flash_attention_2", "xformers", "eager", "sdpa"] + cudagraph_mode = get_current_vllm_config().compilation_config.cudagraph_mode + if "flash_attention_2" in ALL_ATTENTION_FUNCTIONS and cudagraph_mode.has_full_cudagraphs(): + logger.warning( + f"CUDAGraphMode.{cudagraph_mode.name} is currently not supported " + f"with flash attention for Qwen3-Omni talker MTP." + f"removing flash attention from attention_backends" + ) + self.attention_backends.remove("flash_attention_2") def forward( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool = False, + position_ids: torch.LongTensor | None = None, ) -> torch.Tensor: bsz, seq_len, _ = hidden_states.shape - # Flatten to 2-D so vLLM rotary_emb gets [num_tokens, size] - qkv, _ = self.qkv_proj(hidden_states.reshape(bsz * seq_len, -1)) + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # QK-norm -> RoPE (both 2-D) - q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(q.shape) - k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(k.shape) + # Reshape for attention + q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) + k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + v = v.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + # Apply normalization + q = self.q_norm(q).contiguous() + k = self.k_norm(k).contiguous() + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + # Apply RoPE q, k = self.rotary_emb(position_ids, q, k) - # [B, heads, seq, head_dim] for SDPA - q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - attn_out = F.scaled_dot_product_attention( - q, - k, - v, - scale=self.scaling, - is_causal=True, - enable_gqa=self._use_gqa, - ) + # Reshape for attention + q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) + k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + v_heads = v.transpose(1, 2).contiguous() + q_heads = q.transpose(1, 2).contiguous() + k_heads = k.transpose(1, 2).contiguous() + + if past_key_values is not None: + sin, cos = self.rotary_emb.get_cos_sin(seq_len) + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k_heads, v_heads = past_key_values.update(k_heads, v_heads, self.layer_idx, cache_kwargs) + + # Try attention backends in order of preference, with runtime error handling + # This handles cases where the backend is registered but not actually available + attn_output = None + last_error = None + + for backend_name in self.attention_backends: + if backend_name not in ALL_ATTENTION_FUNCTIONS: + continue + + try: + attention_interface = ALL_ATTENTION_FUNCTIONS[backend_name] + attn_output, _ = attention_interface( + self, + q_heads, + k_heads, + v_heads, + None, + dropout=0.0 if not self.training else getattr(self, "attention_dropout", 0.0), + scaling=self.head_dim**-0.5, + sliding_window=None, + use_cache=use_cache, + position_ids=position_ids[:seq_len].unsqueeze(0), + output_hidden_states=True, + output_attentions=False, + ) + break + except (ValueError, ImportError, RuntimeError, AttributeError) as e: + # Store error and try next backend + last_error = e + continue + + if attn_output is None: + raise RuntimeError( + f"All attention backends failed. Last error: {last_error}. " + "Please install flash-attn, or ensure PyTorch's scaled_dot_product_attention is available." + ) + attn_output = attn_output.reshape(*(hidden_states.shape[:-1]), -1).contiguous() - attn_out = attn_out.transpose(1, 2).reshape(bsz * seq_len, -1) - output, _ = self.o_proj(attn_out) - return output.view(bsz, seq_len, -1) + attn_output, _ = self.o_proj(attn_output) + return attn_output -# =================================================================== -# MLP -# =================================================================== +# ============================================================================ +# Code Predictor MLP Layer +# ============================================================================ class Qwen3OmniCodePredictorMLP(nn.Module): - """SiLU-gated MLP for code predictor.""" + """Feed-forward network for code predictor with fused gate/up projection.""" def __init__( self, @@ -151,6 +221,7 @@ def __init__( prefix=f"{prefix}.gate_up_proj", disable_tp=True, ) + self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, @@ -167,23 +238,35 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return down -# =================================================================== -# Decoder Layer -# =================================================================== +# ============================================================================ +# MTP Layer (Multi-Token Prediction Layer) +# ============================================================================ -class Qwen3OmniCodePredictorDecoderLayer(nn.Module): - """Transformer decoder layer (SDPA, no KV cache).""" +class Qwen3OmniCodePredictorMTPLayer(nn.Module): + """MTP layer for speculative decoding - predicts next residual code layer.""" def __init__( self, - config, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + layer_idx: int, + cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, - prefix: str = "", ) -> None: super().__init__() + self.layer_idx = layer_idx + self.config = config + self.self_attn = Qwen3OmniCodePredictorAttention( config, + layer_idx, + vllm_config=type( + "VllmConfig", + (), + {"cache_config": cache_config, "quant_config": quant_config, "model_config": model_config}, + )(), quant_config=quant_config, prefix=f"{prefix}.self_attn", ) @@ -192,247 +275,260 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - cp_cfg = config.code_predictor_config - self.input_layernorm = RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) + self.input_layernorm = RMSNorm( + config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps + ) + self.post_attention_layernorm = RMSNorm( + config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps + ) - def forward( + def mtp_block( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool = False, + position_ids: torch.LongTensor | None = None, ) -> torch.Tensor: + # Self-attention with residual residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states, position_ids) + hidden_states = self.self_attn(hidden_states, past_key_values, cache_position, use_cache, position_ids) hidden_states = residual + hidden_states + # MLP with residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states - -# =================================================================== -# Base Transformer Model (re-prefill, no KV cache) -# =================================================================== + return hidden_states class Qwen3OmniCodePredictorBaseModel(nn.Module): - """Inner transformer for code predictor. + """ + Base model for code predictor - matches HF Qwen3OmniMoeTalkerCodePredictorModel structure. - Signature: ``forward(inputs_embeds, position_ids) -> hidden_states`` - -- plain Tensor in, plain Tensor out (no namedtuple). + This is a simple transformer that processes inputs_embeds and outputs hidden states. """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config.code_predictor_config + self.config = config + self.vocab_size = config.vocab_size + self.num_code_groups = config.num_code_groups + # Codec embeddings (for layers 1-num_code_groups-1) self.codec_embedding = nn.ModuleList( - [VocabParallelEmbedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)] + [ + VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + for _ in range(config.num_code_groups - 1) + ] ) + # Decoder layers self.layers = nn.ModuleList( [ - Qwen3OmniCodePredictorDecoderLayer( + Qwen3OmniCodePredictorMTPLayer( vllm_config.model_config.hf_config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + layer_idx=idx, + cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=f"{prefix}.layers.{idx}", ) for idx in range(config.num_hidden_layers) ] ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, inputs_embeds: torch.Tensor, - position_ids: torch.Tensor, - ) -> torch.Tensor: + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Any | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + **kwargs: Any, + ) -> Any: + """ + Forward pass matching HF structure. + + Args: + inputs_embeds: [batch, seq_len, hidden_size] + position_ids: Optional position IDs tensor + past_key_values: Optional cached key-value pairs + use_cache: Whether to use cache + cache_position: Optional cache position tensor + **kwargs: Additional keyword arguments + + Returns: + Named tuple with .last_hidden_state and .past_key_values attributes + """ + batch_size, seq_len, _ = inputs_embeds.shape + # Forward through decoder layers hidden_states = inputs_embeds + for layer in self.layers: - hidden_states = layer(hidden_states, position_ids) + hidden_states = layer.mtp_block(hidden_states, past_key_values, cache_position, use_cache, position_ids) + + # Final norm hidden_states = self.norm(hidden_states) - return hidden_states + + # Return in HF-compatible format + Output = namedtuple("Output", ["last_hidden_state", "past_key_values"]) + return Output(last_hidden_state=hidden_states, past_key_values=None) # [batch, num_code_groups-1, hidden_size] + + def get_input_embeddings(self): + """Return codec embeddings for HF compatibility.""" + return self.codec_embedding + + +def code_predictor_sample( + logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + logits = self.logits_processors(None, logits[:, -1]) + probs = F.softmax(logits, dim=-1) + code = torch.multinomial(probs.squeeze(1), num_samples=1) # [batch, 1] + return code -# =================================================================== -# Code Predictor Wrapper (optimized re-prefill, persistent buffers) -# =================================================================== +def code_predictor_sample_fake( + logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty((logits.shape[0], 1), dtype=torch.int64, device=logits.device) +direct_register_custom_op( + op_name="qwen3_omni_code_predictor_sample", + op_func=code_predictor_sample, + fake_impl=code_predictor_sample_fake, +) + + +@support_torch_compile class Qwen3OmniMoeTalkerCodePredictor(nn.Module): - """Optimized code predictor -- re-prefill approach, no KV cache. - - Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1) - through the transformer. The extra O(T^2) FLOPs are negligible for - short sequences, and this avoids all KV-cache management overhead. - - Optimizations: - 1. Pre-allocated embedding buffer -- no torch.cat per step. - 2. Pre-allocated position_ids -- no torch.arange per step. - 3. Inline top-k sampling -- no LogitsProcessorList / custom op. - 4. Cached module references -- bypass ModuleList indexing. - 5. torch.compile on inner transformer. + """ + Code predictor wrapper matching HF structure. + + Structure: + - self.model: Qwen3OmniCodePredictorBaseModel (transformer) + - self.lm_head: ModuleList of output heads """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config - self.config = config + talker_code_predictor_config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config self.prefix = prefix - self.num_code_groups = config.code_predictor_config.num_code_groups - self._hidden_size = config.code_predictor_config.hidden_size + self.config = talker_code_predictor_config + self.vocab_size = self.config.code_predictor_config.vocab_size + self.num_code_groups = self.config.code_predictor_config.num_code_groups - self.model = Qwen3OmniCodePredictorBaseModel( - vllm_config=vllm_config, - prefix=prefix, - ) + # Base transformer model (matches HF structure) + self.model = Qwen3OmniCodePredictorBaseModel(vllm_config=vllm_config, prefix=prefix) - # One lm_head per residual layer (layers 1 .. G-1) + # Output heads for each residual layer (1-num_layers-1) self.lm_head = nn.ModuleList( [ nn.Linear( - config.code_predictor_config.hidden_size, - config.code_predictor_config.vocab_size, + self.config.code_predictor_config.hidden_size, + self.config.code_predictor_config.vocab_size, bias=False, ) for _ in range(self.num_code_groups - 1) ] ) - - # Sampling hyperparams (inlined) - self._top_k = 50 - - # Persistent buffers (lazily initialised on first forward) - self._proj_buf: torch.Tensor | None = None - self._pos_ids: torch.Tensor | None = None - - # Cached plain-list refs (set once) - self._lm_heads: list | None = None - self._codec_embeds: list | None = None - - # Model forward (optionally compiled) - self._model_fwd: object | None = None - - # ------------------------------------------------------------------ - # Lazy-init helpers - # ------------------------------------------------------------------ - - def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: - max_seq = self.num_code_groups + 1 - if ( - self._proj_buf is not None - and self._proj_buf.shape[0] >= bsz - and self._proj_buf.device == device - and self._proj_buf.dtype == dtype - ): - return - self._proj_buf = torch.zeros(bsz, max_seq, self._hidden_size, dtype=dtype, device=device) - self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device) - - def _ensure_cached_refs(self) -> None: - if self._lm_heads is not None: - return - self._lm_heads = list(self.lm_head) - self._codec_embeds = list(self.model.codec_embedding) - - def _ensure_model_fwd(self) -> None: - if self._model_fwd is not None: - return - if not current_omni_platform.supports_torch_inductor(): - logger.warning_once("code_predictor: torch.compile disabled") - self._model_fwd = self.model.forward - return - self._model_fwd = torch.compile( - self.model.forward, - mode="default", - dynamic=True, + self.logits_processors = LogitsProcessorList( + [ + TopKLogitsWarper(top_k=50), + TopPLogitsWarper(top_p=0.8), + ] ) - logger.info("code_predictor: torch.compile enabled") - # ------------------------------------------------------------------ - # Forward -- re-prefill + persistent buffers + inline sampling - # ------------------------------------------------------------------ + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix - @torch.inference_mode() def forward( self, layer0_code: torch.Tensor, layer0_embed: torch.Tensor, last_talker_hidden: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Predict residual codebooks 1..G-1 autoregressively via re-prefill. + """ + Forward pass for code predictor. Args: - layer0_code: [bsz, 1] int64 - layer0_embed: [bsz, 1, hidden_size] - last_talker_hidden: [bsz, 1, hidden_size] + layer0_code: + Code index for code-group (layer) 0. + Shape: [batch_size, 1], dtype typically int64. + + last_talker_hidden: + + Shape: [batch_size, hidden_size]. Returns: - all_codes: [bsz, num_code_groups, 1] - proj_buf: [bsz, num_code_groups + 1, hidden_size] - pos 0 = last_talker_hidden (NOT a codec embed) - pos 1 = layer0_embed - pos 2.. = codec_embedding[i](predicted_code_i) + pos_all_layers: + Predicted codes for all code groups, including `layer0_code`. + Shape: [batch_size, num_code_groups, 1]. + + current_input: + The final input embedding sequence after appending embeddings of all + predicted codes (one token per predicted layer). + Shape: [batch_size, num_code_groups + 2, hidden_size]. """ - bsz = int(layer0_code.shape[0]) - device = layer0_code.device - dtype = last_talker_hidden.dtype - num_groups = self.num_code_groups - top_k = self._top_k - - # Lazy init - self._ensure_buffers(bsz, device, dtype) - self._ensure_model_fwd() - self._ensure_cached_refs() - - proj_buf = self._proj_buf - pos_ids = self._pos_ids - model_fwd = self._model_fwd - lm_heads = self._lm_heads - codec_embeds = self._codec_embeds - - # Output codes - all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device) - all_codes[:, 0] = layer0_code - - # Fill buffer positions 0 & 1 - proj_buf[:bsz, 0:1, :] = last_talker_hidden - proj_buf[:bsz, 1:2, :] = layer0_embed - - # Autoregressive loop: predict layers 1..G-1 - for step in range(1, num_groups): - seq_len = step + 1 - projected = proj_buf[:bsz, :seq_len, :] - step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz) - - hidden_out = model_fwd(projected, step_pos_ids) - - # Inline top-k sampling - logits = lm_heads[step - 1](hidden_out[:, -1, :]) - if top_k > 0: - topk_vals, _ = logits.topk(top_k, dim=-1) - logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf")) - probs = F.softmax(logits, dim=-1) - code = torch.multinomial(probs, num_samples=1) # [bsz, 1] - - all_codes[:, step] = code - - # Embed predicted code -> next buffer position - new_embed = codec_embeds[step - 1](code) - proj_buf[:bsz, step + 1 : step + 2, :] = new_embed - - return all_codes, proj_buf[:bsz] - - # ------------------------------------------------------------------ - # Weight loading - # ------------------------------------------------------------------ + pos_codes = [layer0_code] # Start with layer 0: [batch, 1] + try: + current_input = torch.cat([last_talker_hidden, layer0_embed], dim=1) # [batch, 2, hidden_size] + except Exception as e: + print(f"Error in current_input: {e}") + print(f"last_talker_hidden shape: {last_talker_hidden.shape}") + print(f"prev_embed shape: {layer0_embed.shape}") + raise e + batch_size = current_input.shape[0] + + # Predict all residual layers (layers 1 to num_code_groups-1) autoregressively + for layer_idx in range(self.num_code_groups - 1): + seq_len = layer_idx + 2 + # Compute position_ids dynamically to avoid torch.compile specializing batch_size + position_ids = torch.arange(seq_len, device=current_input.device, dtype=torch.int64).repeat(batch_size) + # Forward through code_predictor model + outputs = self.model( + inputs_embeds=current_input, + attention_mask=None, + position_ids=position_ids, + past_key_values=None, + use_cache=False, + cache_position=None, + ) + hidden_state = outputs.last_hidden_state # [batch, 2, hidden_size] + + # Use the corresponding lm_head for this layer + logits = self.lm_head[layer_idx](hidden_state[:, -1:, :]) + code = torch.ops.vllm.qwen3_omni_code_predictor_sample(logits, self.layer_name) + pos_codes.append(code) + # Update prev_embed for next layer (if not last layer) + # layer_idx=0 predicts layer 1, embed with codec_embedding[1] + new_embed = self.model.codec_embedding[layer_idx](code) # [batch, 1, hidden_size] + current_input = torch.cat([current_input, new_embed], dim=1) # [batch, 3~n, hidden_size] + pos_all_layers = torch.stack(pos_codes, dim=1) # [batch, num_code_groups, 1] + return pos_all_layers, current_input def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]: """Load weights with mapping for fused QKV and gate_up projections. diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 9f5c20cf7e4..3362ae17fac 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -129,70 +129,111 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.code_predictor = Qwen3OmniMoeTalkerCodePredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "code_predictor") ) + max_batch_size = max( + vllm_config.scheduler_config.max_num_seqs, vllm_config.compilation_config.max_cudagraph_capture_size + ) + self.layer0_embed_buffer = torch.zeros( + (max_batch_size, 1, self.config.text_config.hidden_size), + dtype=vllm_config.model_config.dtype, + ) def code_predictor_forward( self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor | None = None, *, + temperature: float = 1.0, + top_k: int = 50, # Match transformers default + top_p: float = 0.8, # Match transformers default + generation_steps: int | None = None, last_talker_hidden: torch.Tensor | None = None, **_: object, ) -> tuple[torch.Tensor, torch.Tensor]: - """Generate full RVQ codes + summed embeddings (single-loop, no KV cache). + """ + Generate full RVQ codec codes for the provided sequence. - The code predictor uses re-prefill: each AR step re-forwards the full - (short) sequence through the transformer. The returned ``proj_buf`` - already contains all codec embeddings at positions 1..G, - so summed_embeddings = proj_buf[:, 1:, :].sum(dim=1) — no second - loop or re-embedding needed. + The code predictor consumes the layer-0 codec codes produced by the talker + alongside the talker's hidden states, and autoregressively predicts the remaining + residual layers (to num_codec_groups). Returns: - result_codes: [batch, num_code_groups, seq_len] - summed_embeddings: [batch, seq_len, hidden_size] + tuple containing: + - residual_codes: A tensor of shape [batch, num_code_groups, seq_len] containing + the complete set of codec codes + - summed_embeddings: A tensor of shape [batch, seq_len, hidden_size] + Sum of all layer embeddings at each position (like Transformers) """ if input_ids is None: - raise ValueError("`input_ids` (layer-0 codes) must be provided.") + raise ValueError("`input_ids` containing layer-0 codec codes must be provided.") if inputs_embeds is None: - raise ValueError("`inputs_embeds` (talker hidden states) must be provided.") + raise ValueError("`inputs_embeds` containing talker hidden states must be provided.") if inputs_embeds.ndim == 2: inputs_embeds = inputs_embeds.unsqueeze(0) if input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) + # Ensure the tensors are contiguous for the autoregressive sampling loop + inputs_embeds = inputs_embeds.contiguous() + input_ids = input_ids.contiguous() + + # Generate full codec codes using MTP + # This will be the parallel prediction implementation batch_size, seq_len = input_ids.shape - device = input_ids.device - embed_fn = self.language_model.model.codec_embedding - hidden_size = self.config.code_predictor_config.hidden_size - - result_codes = torch.empty( - batch_size, - self.num_code_groups, - seq_len, - dtype=torch.int64, - device=device, - ) - summed_embeddings = torch.empty( - batch_size, - seq_len, - hidden_size, - dtype=inputs_embeds.dtype, - device=device, - ) + # For now, use sequential generation (TODO: implement parallel) + # Result will be [batch, num_code_groups, seq_len] + # - all_codes_per_position will collect [batch, num_code_groups, 1] for each position + all_codes_per_position = [] + middle_hidden_states = [] # Collect hidden states for each position + + # Generate residual layers for each position for pos in range(seq_len): - layer0_code = input_ids[:, pos : pos + 1] - layer0_embed = embed_fn(layer0_code) + layer0_code = input_ids[:, pos : pos + 1] # [batch, 1] - pos_all_layers, proj_buf = self.code_predictor( - layer0_code, - layer0_embed, - last_talker_hidden, + # Initial input: [last_talker_hidden, layer0_embed] + layer0_embed = self.embed_input_ids(layer0_code) + self.layer0_embed_buffer[:batch_size].copy_(layer0_embed) + pos_all_layers, current_input = self.code_predictor( + layer0_code, self.layer0_embed_buffer[:batch_size], last_talker_hidden ) - result_codes[:, :, pos : pos + 1] = pos_all_layers - # proj_buf layout: [0]=talker_hidden, [1..G]=codec embeds - summed_embeddings[:, pos, :] = proj_buf[:, 1:, :].sum(dim=1) + # Stack all layers for this position: [batch, num_code_groups, 1] + all_codes_per_position.append(pos_all_layers) + middle_hidden_states.append(current_input[:, 2:-1, :]) + + # Concatenate across positions: [batch, num_code_groups, seq_len] + result_codes = torch.cat(all_codes_per_position, dim=2) + + # Build summed embeddings for each position (like Transformers) + # This combines layer-0 embed, mid layers hidden states, and last layer embed + all_summed_embeddings = [] + + for pos in range(seq_len): + # Layer 0 embedding + layer0_code = result_codes[:, 0, pos : pos + 1] # [batch, 1] + layer0_embed = self.embed_input_ids(layer0_code) # [batch, 1, hidden_size] + + # mid layers hidden states (from CodePredictor) + mid_residual_hiddens = middle_hidden_states[pos] # [batch, num_code_groups-2, hidden_size] + mid_list = list(mid_residual_hiddens.split(1, dim=1)) + + # last layer embedding + last_layer_code = result_codes[:, -1, pos : pos + 1] # [batch, 1] + last_residual_hidden = self.code_predictor.model.codec_embedding[-1](last_layer_code) + + # Concatenate all layers: [batch, num_code_groups, hidden_size] + pos_codec_hiddens = torch.cat( + [layer0_embed] + mid_list + [last_residual_hidden], + dim=1, + ) + + # Sum across layers: [batch, 1, hidden_size] (like Transformers) + pos_summed = pos_codec_hiddens.sum(dim=1, keepdim=True) + all_summed_embeddings.append(pos_summed) + + # Concatenate across positions: [batch, seq_len, hidden_size] + summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1) return result_codes, summed_embeddings diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 17acbe8b005..8725c62f66c 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1341,17 +1341,10 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: req_state = self.requests.get(req_id) if req_state is None: return - # Check if the model declares keys that should stay on GPU - gpu_keys: set[str] = set() - if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): - gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) for k, v in upd.items(): if isinstance(v, torch.Tensor): - if k in gpu_keys: - existing[k] = v.detach().clone() - else: - existing[k] = v.detach().to("cpu").contiguous() + existing[k] = v.detach().to("cpu").contiguous() elif isinstance(v, list): existing[k] = [ (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v