diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index ec928195ab7..21f0185aa3f 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -159,6 +159,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # suppress tokens by setting their probability to ~1e-9 (finite very small) self.suppressed_tokens = self._get_talker_suppressed_tokens() self.requires_raw_input_tokens = True + # Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips + self.gpu_resident_buffer_keys: set[str] = { + "last_talker_hidden", + "trailing_text_hidden", + "tts_pad_embed_projected", + } elif self.model_stage == "code2wav": self.enable_update_additional_information = True @@ -223,14 +229,16 @@ def embed_multimodal(self, **kwargs): # ==================== Forward Pass ==================== def _get_talker_suppressed_tokens(self): - return [ - i - for i in range( - self.config.talker_config.text_config.vocab_size - 1024, - self.config.talker_config.text_config.vocab_size, - ) - if i != self.config.talker_config.codec_eos_token_id - ] + """Return a boolean mask on GPU for suppressed token positions.""" + vocab_size = self.config.talker_config.text_config.vocab_size + mask = torch.zeros(vocab_size, dtype=torch.bool) + start = vocab_size - 1024 + eos_id = self.config.talker_config.codec_eos_token_id + for i in range(start, vocab_size): + if i != eos_id: + mask[i] = True + # Will be moved to the correct device on first use + return mask def get_mrope_input_positions( self, @@ -578,7 +586,7 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object): Postprocess the talker hidden states. """ update_dict = {} - update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous() + update_dict["last_talker_hidden"] = hidden_states[-1, :].detach() return update_dict def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): @@ -632,9 +640,11 @@ def talker_mtp( if inputs_embeds.shape[-1] == 2048: inputs_embeds = self.text_projection(inputs_embeds) code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( - input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden + input_ids, inputs_embeds, last_talker_hidden=last_talker_hidden ) - inputs_embeds = summed_embeddings.clone() + # summed_embeddings is [B, seq_len, H] (3D) while text_step is [B, H] (2D). + # Flatten to 2D first to avoid wrong broadcasting: [B,1,H]+[B,H] → [B,B,H] + inputs_embeds = summed_embeddings.reshape(-1, self.talker_config.text_config.hidden_size) inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size) return inputs_embeds, code_predictor_codes.squeeze(-1) @@ -758,7 +768,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch # compatible with old shape [1,S,D] rem_tail = trailing_text_hidden.squeeze(0) if rem_tail.shape[0] > 0: - update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous() + update_dict["trailing_text_hidden"] = rem_tail.detach() # Also persist projected tts_pad for decode fallback if needed if isinstance(tts_pad_thinker, torch.Tensor): pad_in = tts_pad_thinker @@ -767,7 +777,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch if pad_in.ndim == 1: pad_in = pad_in.view(1, 1, -1) pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker))) - update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous() + update_dict["tts_pad_embed_projected"] = pad_proj.detach() except Exception: pass self._talker_cache_thinker_decode_embeds(info_dict, update_dict) @@ -926,7 +936,7 @@ def talker_preprocess_decode( if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0: use_vec = q_tail[0:1, :] new_q_tail = ( - q_tail[1:, :].detach().to("cpu").contiguous() + q_tail[1:, :].detach() if q_tail.shape[0] > 1 else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) ) @@ -1122,16 +1132,10 @@ def compute_logits( # implemented by assigning their logits to log(1e-9). if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor): - try: - logits_cpu = logits.cpu() - logits_cpu[:, self.suppressed_tokens] = -1e9 - logits = logits_cpu.to(logits.device) - except Exception as e: - print(f"Error in logits suppression: {e}") - print(f"logits.shape: {logits.shape}") - print(f"self.suppressed_tokens: {self.suppressed_tokens}") - raise e - logits[:, self.suppressed_tokens] = -1e9 + # Move mask to device once (lazy), then reuse every step + if self.suppressed_tokens.device != logits.device: + self.suppressed_tokens = self.suppressed_tokens.to(logits.device) + logits.masked_fill_(self.suppressed_tokens.unsqueeze(0), -1e9) return logits def sample( diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py index fc7402890ab..b1f0b2237e0 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -1,27 +1,16 @@ -"""Qwen3-Omni Code Predictor with MTP (Multi-Token Prediction) support. +"""Qwen3-Omni Code Predictor -- optimized re-prefill, no KV cache. -This module implements the code predictor component for Qwen3-Omni talker models. - -The code predictor generates residual RVQ (Residual Vector Quantization) codes -autoregressively, predicting layers 1 to N based on layer-0 codes from the talker. +* SDPA attention (F.scaled_dot_product_attention) with native GQA support +* Per-call embedding buffer to avoid cross-request aliasing +* Pre-allocated position_ids (read-only, safe to persist) +* torch.compile on inner transformer by default +* Inline sampling (top-k + top-p) -- no custom op overhead """ -from collections import namedtuple -from typing import Any - import torch import torch.nn as nn import torch.nn.functional as F -from transformers import Cache, PretrainedConfig -from transformers.generation.logits_process import ( - LogitsProcessorList, - TopKLogitsWarper, - TopPLogitsWarper, -) -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config -from vllm.forward_context import get_forward_context +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -32,47 +21,53 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.utils.torch_utils import direct_register_custom_op + +from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) -# ============================================================================ -# Code Predictor Attention Layer -# ============================================================================ + +# =================================================================== +# Standalone Attention (SDPA, no KV cache, no HF backend fallback) +# =================================================================== class Qwen3OmniCodePredictorAttention(nn.Module): - """Multi-head self-attention for code predictor with vLLM optimization.""" + """Multi-head self-attention for code predictor. + + Uses ``F.scaled_dot_product_attention`` directly. No KV cache -- the code + predictor always re-prefills the full (short) sequence each AR step. + + Input : [B, seq_len, hidden_size] + Output: [B, seq_len, hidden_size] + """ def __init__( self, config, - layer_idx: int, - vllm_config: VllmConfig = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() - - self.num_heads = config.code_predictor_config.num_attention_heads - self.num_key_value_heads = config.code_predictor_config.num_key_value_heads + cp_cfg = config.code_predictor_config + self.num_heads = cp_cfg.num_attention_heads + self.num_kv_heads = cp_cfg.num_key_value_heads self.head_dim = getattr( - config.code_predictor_config, + cp_cfg, "head_dim", - config.code_predictor_config.hidden_size // config.code_predictor_config.num_attention_heads, + cp_cfg.hidden_size // cp_cfg.num_attention_heads, ) - self.hidden_size = config.code_predictor_config.hidden_size - - if self.num_heads % self.num_key_value_heads != 0: - raise ValueError("num_attention_heads must be divisible by num_key_value_heads") - - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.hidden_size = cp_cfg.hidden_size + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self._use_gqa = self.num_kv_heads != self.num_heads self.qkv_proj = QKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, total_num_heads=self.num_heads, - total_num_kv_heads=self.num_key_value_heads, + total_num_kv_heads=self.num_kv_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -88,120 +83,55 @@ def __init__( ) self.rotary_emb = get_rope( self.head_dim, - max_position=config.code_predictor_config.max_position_embeddings, + max_position=cp_cfg.max_position_embeddings, rope_parameters=None, dual_chunk_attention_config=None, ) - - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_key_value_heads * self.head_dim - - # Query/Key normalization - self.q_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) - self.is_causal = True - self.config = config - - self.attention_backends = ["flash_attention_2", "xformers", "eager", "sdpa"] - cudagraph_mode = get_current_vllm_config().compilation_config.cudagraph_mode - if "flash_attention_2" in ALL_ATTENTION_FUNCTIONS and cudagraph_mode.has_full_cudagraphs(): - logger.warning( - f"CUDAGraphMode.{cudagraph_mode.name} is currently not supported " - f"with flash attention for Qwen3-Omni talker MTP." - f"removing flash attention from attention_backends" - ) - self.attention_backends.remove("flash_attention_2") + self.q_norm = RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - use_cache: bool = False, - position_ids: torch.LongTensor | None = None, + position_ids: torch.Tensor, ) -> torch.Tensor: bsz, seq_len, _ = hidden_states.shape - qkv, _ = self.qkv_proj(hidden_states) + # Flatten to 2-D so vLLM rotary_emb gets [num_tokens, size] + qkv, _ = self.qkv_proj(hidden_states.reshape(bsz * seq_len, -1)) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Reshape for attention - q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) - k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) - v = v.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) - - # Apply normalization - q = self.q_norm(q).contiguous() - k = self.k_norm(k).contiguous() - q = q.reshape(-1, self.q_size) - k = k.reshape(-1, self.kv_size) - - # Apply RoPE + # QK-norm -> RoPE (both 2-D) + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(q.shape) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(k.shape) q, k = self.rotary_emb(position_ids, q, k) - # Reshape for attention - q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) - k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) - - v_heads = v.transpose(1, 2).contiguous() - q_heads = q.transpose(1, 2).contiguous() - k_heads = k.transpose(1, 2).contiguous() - - if past_key_values is not None: - sin, cos = self.rotary_emb.get_cos_sin(seq_len) - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k_heads, v_heads = past_key_values.update(k_heads, v_heads, self.layer_idx, cache_kwargs) - - # Try attention backends in order of preference, with runtime error handling - # This handles cases where the backend is registered but not actually available - attn_output = None - last_error = None - - for backend_name in self.attention_backends: - if backend_name not in ALL_ATTENTION_FUNCTIONS: - continue - - try: - attention_interface = ALL_ATTENTION_FUNCTIONS[backend_name] - attn_output, _ = attention_interface( - self, - q_heads, - k_heads, - v_heads, - None, - dropout=0.0 if not self.training else getattr(self, "attention_dropout", 0.0), - scaling=self.head_dim**-0.5, - sliding_window=None, - use_cache=use_cache, - position_ids=position_ids[:seq_len].unsqueeze(0), - output_hidden_states=True, - output_attentions=False, - ) - break - except (ValueError, ImportError, RuntimeError, AttributeError) as e: - # Store error and try next backend - last_error = e - continue - - if attn_output is None: - raise RuntimeError( - f"All attention backends failed. Last error: {last_error}. " - "Please install flash-attn, or ensure PyTorch's scaled_dot_product_attention is available." - ) - attn_output = attn_output.reshape(*(hidden_states.shape[:-1]), -1).contiguous() + # [B, heads, seq, head_dim] + q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + scale=self.scaling, + is_causal=True, + enable_gqa=self._use_gqa, + ) - attn_output, _ = self.o_proj(attn_output) - return attn_output + attn_out = attn_out.transpose(1, 2).reshape(bsz * seq_len, -1) + output, _ = self.o_proj(attn_out) + return output.view(bsz, seq_len, -1) -# ============================================================================ -# Code Predictor MLP Layer -# ============================================================================ +# =================================================================== +# MLP +# =================================================================== class Qwen3OmniCodePredictorMLP(nn.Module): - """Feed-forward network for code predictor with fused gate/up projection.""" + """SiLU-gated MLP for code predictor.""" def __init__( self, @@ -221,7 +151,6 @@ def __init__( prefix=f"{prefix}.gate_up_proj", disable_tp=True, ) - self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, @@ -238,35 +167,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return down -# ============================================================================ -# MTP Layer (Multi-Token Prediction Layer) -# ============================================================================ +# =================================================================== +# Decoder Layer +# =================================================================== -class Qwen3OmniCodePredictorMTPLayer(nn.Module): - """MTP layer for speculative decoding - predicts next residual code layer.""" +class Qwen3OmniCodePredictorDecoderLayer(nn.Module): + """Transformer decoder layer (SDPA, no KV cache).""" def __init__( self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - layer_idx: int, - cache_config: CacheConfig | None = None, + config, quant_config: QuantizationConfig | None = None, + prefix: str = "", ) -> None: super().__init__() - self.layer_idx = layer_idx - self.config = config - self.self_attn = Qwen3OmniCodePredictorAttention( config, - layer_idx, - vllm_config=type( - "VllmConfig", - (), - {"cache_config": cache_config, "quant_config": quant_config, "model_config": model_config}, - )(), quant_config=quant_config, prefix=f"{prefix}.self_attn", ) @@ -275,260 +192,258 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm( - config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps - ) - self.post_attention_layernorm = RMSNorm( - config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps - ) + cp_cfg = config.code_predictor_config + self.input_layernorm = RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) - def mtp_block( + def forward( self, hidden_states: torch.Tensor, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - use_cache: bool = False, - position_ids: torch.LongTensor | None = None, + position_ids: torch.Tensor, ) -> torch.Tensor: - # Self-attention with residual residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states, past_key_values, cache_position, use_cache, position_ids) + hidden_states = self.self_attn(hidden_states, position_ids) hidden_states = residual + hidden_states - # MLP with residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states +# =================================================================== +# Base Transformer Model (re-prefill, no KV cache) +# =================================================================== + + class Qwen3OmniCodePredictorBaseModel(nn.Module): - """ - Base model for code predictor - matches HF Qwen3OmniMoeTalkerCodePredictorModel structure. + """Inner transformer for code predictor. - This is a simple transformer that processes inputs_embeds and outputs hidden states. + Signature: ``forward(inputs_embeds, position_ids) -> hidden_states`` + -- plain Tensor in, plain Tensor out (no namedtuple). """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config.code_predictor_config - self.config = config - self.vocab_size = config.vocab_size - self.num_code_groups = config.num_code_groups - # Codec embeddings (for layers 1-num_code_groups-1) self.codec_embedding = nn.ModuleList( - [ - VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - for _ in range(config.num_code_groups - 1) - ] + [VocabParallelEmbedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)] ) - # Decoder layers self.layers = nn.ModuleList( [ - Qwen3OmniCodePredictorMTPLayer( + Qwen3OmniCodePredictorDecoderLayer( vllm_config.model_config.hf_config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - layer_idx=idx, - cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, + prefix=f"{prefix}.layers.{idx}", ) for idx in range(config.num_hidden_layers) ] ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, inputs_embeds: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Any | None = None, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - **kwargs: Any, - ) -> Any: - """ - Forward pass matching HF structure. - - Args: - inputs_embeds: [batch, seq_len, hidden_size] - position_ids: Optional position IDs tensor - past_key_values: Optional cached key-value pairs - use_cache: Whether to use cache - cache_position: Optional cache position tensor - **kwargs: Additional keyword arguments - - Returns: - Named tuple with .last_hidden_state and .past_key_values attributes - """ - batch_size, seq_len, _ = inputs_embeds.shape - # Forward through decoder layers + position_ids: torch.Tensor, + ) -> torch.Tensor: hidden_states = inputs_embeds - for layer in self.layers: - hidden_states = layer.mtp_block(hidden_states, past_key_values, cache_position, use_cache, position_ids) - - # Final norm + hidden_states = layer(hidden_states, position_ids) hidden_states = self.norm(hidden_states) - - # Return in HF-compatible format - Output = namedtuple("Output", ["last_hidden_state", "past_key_values"]) - return Output(last_hidden_state=hidden_states, past_key_values=None) # [batch, num_code_groups-1, hidden_size] - - def get_input_embeddings(self): - """Return codec embeddings for HF compatibility.""" - return self.codec_embedding - - -def code_predictor_sample( - logits: torch.Tensor, - layer_name: str, -) -> torch.Tensor: - forward_context = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - logits = self.logits_processors(None, logits[:, -1]) - probs = F.softmax(logits, dim=-1) - code = torch.multinomial(probs.squeeze(1), num_samples=1) # [batch, 1] - return code - - -def code_predictor_sample_fake( - logits: torch.Tensor, - layer_name: str, -) -> torch.Tensor: - return torch.empty((logits.shape[0], 1), dtype=torch.int64, device=logits.device) + return hidden_states -direct_register_custom_op( - op_name="qwen3_omni_code_predictor_sample", - op_func=code_predictor_sample, - fake_impl=code_predictor_sample_fake, -) +# =================================================================== +# Code Predictor Wrapper (optimized re-prefill, persistent buffers) +# =================================================================== -@support_torch_compile class Qwen3OmniMoeTalkerCodePredictor(nn.Module): - """ - Code predictor wrapper matching HF structure. - - Structure: - - self.model: Qwen3OmniCodePredictorBaseModel (transformer) - - self.lm_head: ModuleList of output heads + """Optimized code predictor -- re-prefill approach, no KV cache. + + Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1) + through the transformer. The extra O(T^2) FLOPs are negligible for + short sequences, and this avoids all KV-cache management overhead. + + Optimizations: + 1. Per-call embedding buffer -- avoids cross-request aliasing. + 2. Pre-allocated position_ids -- no torch.arange per step. + 3. Cached module references -- bypass ModuleList indexing. + 4. torch.compile on inner transformer. + 5. Inline sampling (top-k + top-p) -- no custom op overhead. """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - talker_code_predictor_config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config + self.config = config self.quant_config = vllm_config.quant_config self.prefix = prefix - self.config = talker_code_predictor_config - self.vocab_size = self.config.code_predictor_config.vocab_size - self.num_code_groups = self.config.code_predictor_config.num_code_groups + self.num_code_groups = config.code_predictor_config.num_code_groups + self._hidden_size = config.code_predictor_config.hidden_size - # Base transformer model (matches HF structure) - self.model = Qwen3OmniCodePredictorBaseModel(vllm_config=vllm_config, prefix=prefix) + self.model = Qwen3OmniCodePredictorBaseModel( + vllm_config=vllm_config, + prefix=prefix, + ) - # Output heads for each residual layer (1-num_layers-1) + # One lm_head per residual layer (layers 1 .. G-1) self.lm_head = nn.ModuleList( [ nn.Linear( - self.config.code_predictor_config.hidden_size, - self.config.code_predictor_config.vocab_size, + config.code_predictor_config.hidden_size, + config.code_predictor_config.vocab_size, bias=False, ) for _ in range(self.num_code_groups - 1) ] ) - self.logits_processors = LogitsProcessorList( - [ - TopKLogitsWarper(top_k=50), - TopPLogitsWarper(top_p=0.8), - ] - ) - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix + self.set_sampling_params() + + # Lazily initialised position ids (read-only, safe to persist) + self._pos_ids: torch.Tensor | None = None + + # Cached plain-list refs (set once) + self._lm_heads: list | None = None + self._codec_embeds: list | None = None + + # Model forward (optionally compiled) + self._model_fwd: object | None = None + + def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8): + """Configure sampling parameters to maintain consistency with previous implementation.""" + self._top_k = top_k + self._top_p = top_p + logger.debug(f"Sampling parameters updated: top_k={top_k}, top_p={top_p}s") + + # ------------------------------------------------------------------ + # Lazy-init helpers + # ------------------------------------------------------------------ + + def _ensure_pos_ids(self, device: torch.device) -> None: + if self._pos_ids is not None and self._pos_ids.device == device: + return + max_seq = self.num_code_groups + 1 + self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device) + + def _ensure_cached_refs(self) -> None: + if self._lm_heads is not None: + return + self._lm_heads = list(self.lm_head) + self._codec_embeds = list(self.model.codec_embedding) + + def _ensure_model_fwd(self) -> None: + if self._model_fwd is not None: + return + if current_omni_platform.supports_torch_inductor(): + # torch.compile fuses the 5-layer transformer's small kernels, + # reducing BF16 intermediate round-trips and improving precision. + # mode="default" avoids Inductor's own CUDA graph capture so it + # doesn't conflict with the outer CUDAGraphWrapper. + self._model_fwd = torch.compile( + self.model.forward, + mode="default", + dynamic=True, + ) + logger.info("code_predictor: torch.compile enabled (mode=default)") + else: + self._model_fwd = self.model.forward + logger.info("code_predictor: using eager mode (no torch.compile)") + + # ------------------------------------------------------------------ + # Forward -- re-prefill + inline sampling + # ------------------------------------------------------------------ + @torch.inference_mode() def forward( self, layer0_code: torch.Tensor, layer0_embed: torch.Tensor, last_talker_hidden: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass for code predictor. + """Predict residual codebooks 1..G-1 autoregressively via re-prefill. Args: - layer0_code: - Code index for code-group (layer) 0. - Shape: [batch_size, 1], dtype typically int64. - - last_talker_hidden: - - Shape: [batch_size, hidden_size]. + layer0_code: [bsz, 1] int64 + layer0_embed: [bsz, 1, hidden_size] + last_talker_hidden: [bsz, 1, hidden_size] Returns: - pos_all_layers: - Predicted codes for all code groups, including `layer0_code`. - Shape: [batch_size, num_code_groups, 1]. - - current_input: - The final input embedding sequence after appending embeddings of all - predicted codes (one token per predicted layer). - Shape: [batch_size, num_code_groups + 2, hidden_size]. + all_codes: [bsz, num_code_groups, 1] + proj_buf: [bsz, num_code_groups + 1, hidden_size] + pos 0 = last_talker_hidden (NOT a codec embed) + pos 1 = layer0_embed + pos 2.. = codec_embedding[i](predicted_code_i) """ - pos_codes = [layer0_code] # Start with layer 0: [batch, 1] - try: - current_input = torch.cat([last_talker_hidden, layer0_embed], dim=1) # [batch, 2, hidden_size] - except Exception as e: - print(f"Error in current_input: {e}") - print(f"last_talker_hidden shape: {last_talker_hidden.shape}") - print(f"prev_embed shape: {layer0_embed.shape}") - raise e - batch_size = current_input.shape[0] - - # Predict all residual layers (layers 1 to num_code_groups-1) autoregressively - for layer_idx in range(self.num_code_groups - 1): - seq_len = layer_idx + 2 - # Compute position_ids dynamically to avoid torch.compile specializing batch_size - position_ids = torch.arange(seq_len, device=current_input.device, dtype=torch.int64).repeat(batch_size) - # Forward through code_predictor model - outputs = self.model( - inputs_embeds=current_input, - attention_mask=None, - position_ids=position_ids, - past_key_values=None, - use_cache=False, - cache_position=None, - ) - hidden_state = outputs.last_hidden_state # [batch, 2, hidden_size] - - # Use the corresponding lm_head for this layer - logits = self.lm_head[layer_idx](hidden_state[:, -1:, :]) - code = torch.ops.vllm.qwen3_omni_code_predictor_sample(logits, self.layer_name) - pos_codes.append(code) - # Update prev_embed for next layer (if not last layer) - # layer_idx=0 predicts layer 1, embed with codec_embedding[1] - new_embed = self.model.codec_embedding[layer_idx](code) # [batch, 1, hidden_size] - current_input = torch.cat([current_input, new_embed], dim=1) # [batch, 3~n, hidden_size] - pos_all_layers = torch.stack(pos_codes, dim=1) # [batch, num_code_groups, 1] - return pos_all_layers, current_input + bsz = int(layer0_code.shape[0]) + device = layer0_code.device + dtype = last_talker_hidden.dtype + num_groups = self.num_code_groups + + # Lazy init (read-only caches only) + self._ensure_pos_ids(device) + self._ensure_model_fwd() + self._ensure_cached_refs() + + # Allocate proj_buf locally each call to avoid cross-call aliasing + max_seq = num_groups + 1 + proj_buf = torch.zeros(bsz, max_seq, self._hidden_size, dtype=dtype, device=device) + pos_ids = self._pos_ids + model_fwd = self._model_fwd + lm_heads = self._lm_heads + codec_embeds = self._codec_embeds + + # Output codes + all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device) + all_codes[:, 0] = layer0_code + + # Fill buffer positions 0 & 1 + proj_buf[:bsz, 0:1, :] = last_talker_hidden + proj_buf[:bsz, 1:2, :] = layer0_embed + + # Autoregressive loop: predict layers 1..G-1 + for step in range(1, num_groups): + seq_len = step + 1 + projected = proj_buf[:bsz, :seq_len, :] + step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz) + assert step_pos_ids.shape[0] == bsz * seq_len + + hidden_out = model_fwd(projected, step_pos_ids) + + # Inline sampling: top-k -> top-p -> softmax -> multinomial + logits = lm_heads[step - 1](hidden_out[:, -1, :]) # [bsz, vocab] + if self._top_k > 0: + topk_vals, _ = logits.topk(self._top_k, dim=-1) + logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf")) + if self._top_p < 1.0: + sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) + cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative probability above top_p + remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= self._top_p + sorted_logits[remove_mask] = float("-inf") + logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) + probs = F.softmax(logits, dim=-1) + code = torch.multinomial(probs, num_samples=1) # [bsz, 1] + + all_codes[:, step] = code + + # Embed predicted code -> next buffer position + new_embed = codec_embeds[step - 1](code) # [batch, 1, hidden_size] + proj_buf[:bsz, step + 1 : step + 2, :] = new_embed + + return all_codes, proj_buf[:bsz] + + # ------------------------------------------------------------------ + # Weight loading + # ------------------------------------------------------------------ def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]: """Load weights with mapping for fused QKV and gate_up projections. diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 3362ae17fac..9f5c20cf7e4 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -129,111 +129,70 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.code_predictor = Qwen3OmniMoeTalkerCodePredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "code_predictor") ) - max_batch_size = max( - vllm_config.scheduler_config.max_num_seqs, vllm_config.compilation_config.max_cudagraph_capture_size - ) - self.layer0_embed_buffer = torch.zeros( - (max_batch_size, 1, self.config.text_config.hidden_size), - dtype=vllm_config.model_config.dtype, - ) def code_predictor_forward( self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor | None = None, *, - temperature: float = 1.0, - top_k: int = 50, # Match transformers default - top_p: float = 0.8, # Match transformers default - generation_steps: int | None = None, last_talker_hidden: torch.Tensor | None = None, **_: object, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Generate full RVQ codec codes for the provided sequence. + """Generate full RVQ codes + summed embeddings (single-loop, no KV cache). - The code predictor consumes the layer-0 codec codes produced by the talker - alongside the talker's hidden states, and autoregressively predicts the remaining - residual layers (to num_codec_groups). + The code predictor uses re-prefill: each AR step re-forwards the full + (short) sequence through the transformer. The returned ``proj_buf`` + already contains all codec embeddings at positions 1..G, + so summed_embeddings = proj_buf[:, 1:, :].sum(dim=1) — no second + loop or re-embedding needed. Returns: - tuple containing: - - residual_codes: A tensor of shape [batch, num_code_groups, seq_len] containing - the complete set of codec codes - - summed_embeddings: A tensor of shape [batch, seq_len, hidden_size] - Sum of all layer embeddings at each position (like Transformers) + result_codes: [batch, num_code_groups, seq_len] + summed_embeddings: [batch, seq_len, hidden_size] """ if input_ids is None: - raise ValueError("`input_ids` containing layer-0 codec codes must be provided.") + raise ValueError("`input_ids` (layer-0 codes) must be provided.") if inputs_embeds is None: - raise ValueError("`inputs_embeds` containing talker hidden states must be provided.") + raise ValueError("`inputs_embeds` (talker hidden states) must be provided.") if inputs_embeds.ndim == 2: inputs_embeds = inputs_embeds.unsqueeze(0) if input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) - # Ensure the tensors are contiguous for the autoregressive sampling loop - inputs_embeds = inputs_embeds.contiguous() - input_ids = input_ids.contiguous() - - # Generate full codec codes using MTP - # This will be the parallel prediction implementation batch_size, seq_len = input_ids.shape + device = input_ids.device + embed_fn = self.language_model.model.codec_embedding + hidden_size = self.config.code_predictor_config.hidden_size + + result_codes = torch.empty( + batch_size, + self.num_code_groups, + seq_len, + dtype=torch.int64, + device=device, + ) + summed_embeddings = torch.empty( + batch_size, + seq_len, + hidden_size, + dtype=inputs_embeds.dtype, + device=device, + ) - # For now, use sequential generation (TODO: implement parallel) - # Result will be [batch, num_code_groups, seq_len] - # - all_codes_per_position will collect [batch, num_code_groups, 1] for each position - all_codes_per_position = [] - middle_hidden_states = [] # Collect hidden states for each position - - # Generate residual layers for each position for pos in range(seq_len): - layer0_code = input_ids[:, pos : pos + 1] # [batch, 1] + layer0_code = input_ids[:, pos : pos + 1] + layer0_embed = embed_fn(layer0_code) - # Initial input: [last_talker_hidden, layer0_embed] - layer0_embed = self.embed_input_ids(layer0_code) - self.layer0_embed_buffer[:batch_size].copy_(layer0_embed) - pos_all_layers, current_input = self.code_predictor( - layer0_code, self.layer0_embed_buffer[:batch_size], last_talker_hidden + pos_all_layers, proj_buf = self.code_predictor( + layer0_code, + layer0_embed, + last_talker_hidden, ) - # Stack all layers for this position: [batch, num_code_groups, 1] - all_codes_per_position.append(pos_all_layers) - middle_hidden_states.append(current_input[:, 2:-1, :]) - - # Concatenate across positions: [batch, num_code_groups, seq_len] - result_codes = torch.cat(all_codes_per_position, dim=2) - - # Build summed embeddings for each position (like Transformers) - # This combines layer-0 embed, mid layers hidden states, and last layer embed - all_summed_embeddings = [] - - for pos in range(seq_len): - # Layer 0 embedding - layer0_code = result_codes[:, 0, pos : pos + 1] # [batch, 1] - layer0_embed = self.embed_input_ids(layer0_code) # [batch, 1, hidden_size] - - # mid layers hidden states (from CodePredictor) - mid_residual_hiddens = middle_hidden_states[pos] # [batch, num_code_groups-2, hidden_size] - mid_list = list(mid_residual_hiddens.split(1, dim=1)) - - # last layer embedding - last_layer_code = result_codes[:, -1, pos : pos + 1] # [batch, 1] - last_residual_hidden = self.code_predictor.model.codec_embedding[-1](last_layer_code) - - # Concatenate all layers: [batch, num_code_groups, hidden_size] - pos_codec_hiddens = torch.cat( - [layer0_embed] + mid_list + [last_residual_hidden], - dim=1, - ) - - # Sum across layers: [batch, 1, hidden_size] (like Transformers) - pos_summed = pos_codec_hiddens.sum(dim=1, keepdim=True) - all_summed_embeddings.append(pos_summed) - - # Concatenate across positions: [batch, seq_len, hidden_size] - summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1) + result_codes[:, :, pos : pos + 1] = pos_all_layers + # proj_buf layout: [0]=talker_hidden, [1..G]=codec embeds + summed_embeddings[:, pos, :] = proj_buf[:, 1:, :].sum(dim=1) return result_codes, summed_embeddings diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 9e96744584e..ddc6c119024 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1362,10 +1362,17 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: req_state = self.requests.get(req_id) if req_state is None: return + # Check if the model declares keys that should stay on GPU + gpu_keys: set[str] = set() + if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): + gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) for k, v in upd.items(): if isinstance(v, torch.Tensor): - existing[k] = v.detach().to("cpu").contiguous() + if k in gpu_keys: + existing[k] = v.detach().clone() + else: + existing[k] = v.detach().to("cpu").contiguous() elif isinstance(v, list): existing[k] = [ (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v