diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py index b0ce10a8d5e..8798cb3ca9a 100644 --- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py +++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py @@ -21,7 +21,7 @@ from pytest_mock import MockerFixture # Direct file import to avoid vllm_omni.__init__ patch dependencies. -_BASE = os.path.join( +_MODELS = os.path.join( os.path.dirname(__file__), os.pardir, os.pardir, @@ -30,14 +30,16 @@ "vllm_omni", "model_executor", "models", - "qwen3_tts", ) +_BASE = os.path.join(_MODELS, "qwen3_tts") +_COMMON = os.path.join(_MODELS, "common") def _load_module(name: str, filename: str): path = os.path.abspath(os.path.join(_BASE, filename)) spec = importlib.util.spec_from_file_location(name, path) mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod # register before exec (needed for dataclasses etc.) spec.loader.exec_module(mod) return mod @@ -59,8 +61,17 @@ def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]: weight_utils_mock = mocker.MagicMock() weight_utils_mock.default_weight_loader = lambda p, w: None - pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts") - pkg.__path__ = [os.path.abspath(_BASE)] + tts_pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts") + tts_pkg.__path__ = [os.path.abspath(_BASE)] + + common_pkg = types.ModuleType("vllm_omni.model_executor.models.common") + common_pkg.__path__ = [os.path.abspath(_COMMON)] + + models_pkg = types.ModuleType("vllm_omni.model_executor.models") + models_pkg.__path__ = [os.path.abspath(_MODELS)] + + vllm_parallel_mock = mocker.MagicMock() + vllm_parallel_mock.VocabParallelEmbedding = torch.nn.Embedding return { "vllm_omni": mocker.MagicMock(), @@ -69,9 +80,11 @@ def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]: "vllm.config": mocker.MagicMock(), "vllm.config.vllm": vllm_config_mod, "vllm.model_executor.model_loader.weight_utils": weight_utils_mock, + "vllm.model_executor.layers.vocab_parallel_embedding": vllm_parallel_mock, "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"), - "vllm_omni.model_executor.models": types.ModuleType("vllm_omni.model_executor.models"), - "vllm_omni.model_executor.models.qwen3_tts": pkg, + "vllm_omni.model_executor.models": models_pkg, + "vllm_omni.model_executor.models.common": common_pkg, + "vllm_omni.model_executor.models.qwen3_tts": tts_pkg, } @@ -88,6 +101,15 @@ def _load_target_classes(mocker: MockerFixture): ) sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod + # Load the shared common module (thin wrappers import from it) + common_cp_path = os.path.abspath(os.path.join(_COMMON, "qwen3_code_predictor.py")) + common_spec = importlib.util.spec_from_file_location( + "vllm_omni.model_executor.models.common.qwen3_code_predictor", common_cp_path + ) + common_cp_mod = importlib.util.module_from_spec(common_spec) + sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] = common_cp_mod + common_spec.loader.exec_module(common_cp_mod) + cp_mod = _load_module( "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm", "qwen3_tts_code_predictor_vllm.py", @@ -104,6 +126,7 @@ def loaded_target_classes(mocker: MockerFixture): config_mod.Qwen3TTSTalkerConfig, cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM, cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM, + cp_mod.CodePredictorWrapperConfig, ) @@ -114,6 +137,7 @@ def _make_tiny_config(loaded_target_classes) -> tuple: qwen3_tts_talker_config, _, _, + _, ) = loaded_target_classes cp_config = qwen3_tts_talker_code_predictor_config( vocab_size=64, @@ -145,7 +169,7 @@ class TestCodePredictorDtypeAlignment: def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """_ensure_buffers should create proj_buf with the given dtype.""" - _, _, code_predictor_wrapper, _ = loaded_target_classes + _, _, code_predictor_wrapper, _, _ = loaded_target_classes cp_config, talker_config = _make_tiny_config(loaded_target_classes) vllm_config = _make_vllm_config(mocker) @@ -156,17 +180,17 @@ def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_tar ) # Create buffer in float16 - predictor._ensure_buffers(torch.device("cpu"), torch.float16) + predictor._ensure_buffers(torch.device("cpu"), torch.float16, 4) assert predictor._proj_buf is not None assert predictor._proj_buf.dtype == torch.float16 # Re-create buffer in float32 (different dtype triggers re-allocation) - predictor._ensure_buffers(torch.device("cpu"), torch.float32) + predictor._ensure_buffers(torch.device("cpu"), torch.float32, 4) assert predictor._proj_buf.dtype == torch.float32 def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None: """_warmup_buckets should align proj_buf dtype to model parameters.""" - _, _, code_predictor_wrapper, _ = loaded_target_classes + _, _, code_predictor_wrapper, _, _ = loaded_target_classes cp_config, talker_config = _make_tiny_config(loaded_target_classes) vllm_config = _make_vllm_config(mocker, max_num_seqs=2) @@ -180,7 +204,7 @@ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loade predictor = predictor.to(torch.float16) # Pre-create proj_buf with WRONG dtype (float32) — simulating the bug - predictor._ensure_buffers(torch.device("cpu"), torch.float32) + predictor._ensure_buffers(torch.device("cpu"), torch.float32, 2) assert predictor._proj_buf.dtype == torch.float32 # Simulate _setup_compile having cached model dtype and compiled forward @@ -194,7 +218,7 @@ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loade def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """_setup_compile should cache model parameter dtype.""" - _, _, code_predictor_wrapper, _ = loaded_target_classes + _, _, code_predictor_wrapper, _, _ = loaded_target_classes cp_config, talker_config = _make_tiny_config(loaded_target_classes) vllm_config = _make_vllm_config(mocker, max_num_seqs=2) @@ -211,7 +235,7 @@ def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_ta def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """forward() should not crash when inputs are float32 but model is float16.""" - _, _, code_predictor_wrapper, _ = loaded_target_classes + _, _, code_predictor_wrapper, _, _ = loaded_target_classes cp_config, talker_config = _make_tiny_config(loaded_target_classes) vllm_config = _make_vllm_config(mocker, max_num_seqs=2) @@ -250,9 +274,9 @@ class TestCodePredictorModelDtype: def test_model_forward_float16(self, loaded_target_classes) -> None: """Inner model forward should work in float16.""" - _, _, _, code_predictor_model = loaded_target_classes + _, _, _, code_predictor_model, _ = loaded_target_classes cp_config, _ = _make_tiny_config(loaded_target_classes) - model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float16) + model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float16) bsz, seq_len = 1, 4 inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16) @@ -264,9 +288,9 @@ def test_model_forward_float16(self, loaded_target_classes) -> None: def test_model_forward_float32(self, loaded_target_classes) -> None: """Inner model forward should work in float32.""" - _, _, _, code_predictor_model = loaded_target_classes + _, _, _, code_predictor_model, _ = loaded_target_classes cp_config, _ = _make_tiny_config(loaded_target_classes) - model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float32) + model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float32) bsz, seq_len = 1, 4 inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32) @@ -275,3 +299,37 @@ def test_model_forward_float32(self, loaded_target_classes) -> None: output = model(inputs, pos_ids) assert output.dtype == torch.float32 assert output.shape == (bsz, seq_len, 32) + + +class TestCodePredictorWrapperConfig: + """Test wrapper configuration for different models.""" + + def test_omni_config(self, loaded_target_classes) -> None: + """Qwen3-Omni uses correct wrapper config.""" + _, _, _, _, code_predictor_wrapper_config = loaded_target_classes + config = code_predictor_wrapper_config( + use_cuda_graphs=False, + use_parallel_embedding=True, + use_projection=False, + return_proj_buf=True, + sampling_mode="stored", + ) + assert config.use_cuda_graphs is False + assert config.use_parallel_embedding is True + assert config.return_proj_buf is True + assert config.sampling_mode == "stored" + + def test_tts_config(self, loaded_target_classes) -> None: + """Qwen3-TTS uses correct wrapper config.""" + _, _, _, _, code_predictor_wrapper_config = loaded_target_classes + config = code_predictor_wrapper_config( + use_cuda_graphs=True, + use_parallel_embedding=False, + use_projection=True, + return_proj_buf=False, + sampling_mode="per_call", + ) + assert config.use_cuda_graphs is True + assert config.use_parallel_embedding is False + assert config.return_proj_buf is False + assert config.sampling_mode == "per_call" diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index bf40aa77cd5..3a7fe4bad77 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -192,8 +192,9 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: default_sampling_params: OmniSamplingParams = SPClass(**default_sp) custom_process_input_func: Callable | None = None - if hasattr(stage_config, "custom_process_input_func"): - mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1) + _cpif_path = getattr(stage_config, "custom_process_input_func", None) + if _cpif_path: + mod_path, fn_name = _cpif_path.rsplit(".", 1) custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name) prompt_expand_func: Callable | None = None diff --git a/vllm_omni/model_executor/models/common/__init__.py b/vllm_omni/model_executor/models/common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py new file mode 100644 index 00000000000..3a904442fa8 --- /dev/null +++ b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py @@ -0,0 +1,654 @@ +"""Qwen3 Code Predictor -- optimized re-prefill, no KV cache. + +Shared by Qwen3-Omni and Qwen3-TTS talker models. + +* SDPA attention (F.scaled_dot_product_attention) with native GQA support +* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers) +* Per-call embedding buffer to avoid cross-request aliasing +* Pre-allocated position_ids (read-only, safe to persist) +* torch.compile (epilogue_fusion=False) on inner transformer by default +* Optional manual CUDA graph capture per batch-size bucket +* Inline sampling (top-k + top-p) -- no custom op overhead +""" + +from __future__ import annotations + +import dataclasses +from collections.abc import Iterable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +# =================================================================== +# HF-numerics-compatible layers for code predictor +# =================================================================== +# +# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32, +# rotate_half RoPE) to produce outputs numerically identical to the +# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel, +# get_rope) introduce small precision differences that compound across +# the autoregressive steps of the code predictor, causing severe +# audio quality degradation. +# +# See: https://github.com/vllm-project/vllm-omni/issues/2274 + + +class _RMSNorm(nn.Module): + """RMSNorm matching HuggingFace's implementation exactly. + + Computes variance in float32 to avoid bfloat16 precision loss. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class _RotaryEmbedding(nn.Module): + """RoPE matching HuggingFace's implementation exactly. + + Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False). + """ + + def __init__(self, config) -> None: + super().__init__() + head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) + rope_theta = getattr(config, "rope_theta", 10000.0) + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # position_ids: [batch, seq_len] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 (matching HF) + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# =================================================================== +# Attention +# =================================================================== + + +class CodePredictorAttention(nn.Module): + """Multi-head self-attention for code predictor. + + Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm. + No KV cache -- the code predictor always re-prefills the full (short) + sequence each AR step. + + Input : [B, seq_len, hidden_size] + Output: [B, seq_len, hidden_size] + """ + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) + self.hidden_size = config.hidden_size + self.scaling = self.head_dim**-0.5 + self._use_gqa = self.num_kv_heads != self.num_heads + + # Separate q/k/v projections matching HF (no fused packing) + bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim) + hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim) + + q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2) + k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2) + v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2) + + cos, sin = position_embeddings + # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads + cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim] + sin = sin.unsqueeze(1) + q = (q * cos) + (_rotate_half(q) * sin) + k = (k * cos) + (_rotate_half(k) * sin) + + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + scale=self.scaling, + is_causal=True, + enable_gqa=self._use_gqa, + ) + + attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1) + return self.o_proj(attn_out) + + +# =================================================================== +# MLP +# =================================================================== + + +class CodePredictorMLP(nn.Module): + """SiLU-gated MLP for code predictor, matching HF's implementation.""" + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + + +# =================================================================== +# Decoder Layer +# =================================================================== + + +class CodePredictorDecoderLayer(nn.Module): + """Transformer decoder layer (SDPA, no KV cache).""" + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.self_attn = CodePredictorAttention(config, prefix=f"{prefix}.self_attn") + self.mlp = CodePredictorMLP(config, prefix=f"{prefix}.mlp") + self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +# =================================================================== +# Base Transformer Model (re-prefill, no KV cache) +# =================================================================== + + +class CodePredictorBaseModel(nn.Module): + """Inner transformer for code predictor. + + Signature: ``forward(inputs_embeds, position_ids) -> hidden_states`` + """ + + def __init__( + self, + config, + *, + embedding_dim: int | None = None, + use_parallel_embedding: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + emb_dim = int(embedding_dim) if embedding_dim is not None else int(config.hidden_size) + if use_parallel_embedding: + self.codec_embedding = nn.ModuleList( + [VocabParallelEmbedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] + ) + else: + self.codec_embedding = nn.ModuleList( + [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] + ) + + self.layers = nn.ModuleList( + [ + CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{idx}") + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _RotaryEmbedding(config) + + def get_input_embeddings(self) -> nn.ModuleList: + return self.codec_embedding + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# =================================================================== +# Wrapper Configuration +# =================================================================== + + +@dataclasses.dataclass +class CodePredictorWrapperConfig: + """Controls behavioral differences between model-specific code predictors.""" + + use_cuda_graphs: bool = False + use_parallel_embedding: bool = False + use_projection: bool = False + return_proj_buf: bool = False + sampling_mode: str = "stored" + + +# =================================================================== +# Code Predictor Wrapper (optimized re-prefill, persistent buffers) +# =================================================================== + + +class CodePredictorWrapper(nn.Module): + """Optimized code predictor -- re-prefill approach, no KV cache. + + Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1) + through the transformer. The extra O(T^2) FLOPs are negligible for + short sequences, and this avoids all KV-cache management overhead. + + Optimizations: + 1. Per-call embedding buffer -- avoids cross-request aliasing. + 2. Pre-allocated position_ids -- no torch.arange per step. + 3. Cached module references -- bypass ModuleList indexing. + 4. torch.compile on inner transformer. + 5. Inline sampling (top-k + top-p) -- no custom op overhead. + 6. Optional manual CUDA graph capture per batch-size bucket. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + cp_config, + wrapper_config: CodePredictorWrapperConfig, + talker_hidden_size: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self._vllm_config = vllm_config + self.config = cp_config + self._wrapper_config = wrapper_config + self.prefix = prefix + + self._num_groups = int(cp_config.num_code_groups) + self._cp_hidden = int(cp_config.hidden_size) + + # For Omni backward compat (accessed by the talker) + self.num_code_groups = self._num_groups + + # Determine embedding dimension + _talker_hidden = int(talker_hidden_size) if talker_hidden_size is not None else self._cp_hidden + + self.model = CodePredictorBaseModel( + cp_config, + embedding_dim=_talker_hidden, + use_parallel_embedding=wrapper_config.use_parallel_embedding, + prefix=f"{prefix}.model" if prefix else "model", + ) + + self.lm_head = nn.ModuleList( + [nn.Linear(cp_config.hidden_size, cp_config.vocab_size, bias=False) for _ in range(self._num_groups - 1)] + ) + + # Projection: Identity when hidden sizes match or not needed + if wrapper_config.use_projection and _talker_hidden != self._cp_hidden: + self.small_to_mtp_projection = nn.Linear(_talker_hidden, self._cp_hidden, bias=True) + else: + self.small_to_mtp_projection = nn.Identity() + + # Sampling defaults for "stored" mode + self._top_k: int = 50 + self._top_p: float = 0.8 + + # Lazily initialised state + self._proj_buf: torch.Tensor | None = None + self._model_dtype: torch.dtype | None = None + self._compiled_model_fwd = None + self._bucket_sizes: list[int] = [] + self._bucket_pos_ids: dict[int, torch.Tensor] = {} + self._lm_heads_list: list[nn.Module] | None = None + self._codec_embeds_list: list[nn.Module] | None = None + self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {} + + def get_input_embeddings(self) -> nn.ModuleList: + return self.model.get_input_embeddings() + + def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8) -> None: + """Configure sampling parameters to maintain consistency with previous implementation.""" + self._top_k = top_k + self._top_p = top_p + logger.debug("Sampling parameters updated: top_k=%d, top_p=%.2f", top_k, top_p) + + # ------------------------------------------------------------------ + # Lazy-init helpers + # ------------------------------------------------------------------ + + def _ensure_buffers(self, device: torch.device, dtype: torch.dtype, bsz: int) -> None: + """Ensure the projection buffer can hold at least *bsz* rows.""" + max_seq = self._num_groups + 1 + if ( + self._proj_buf is not None + and self._proj_buf.device == device + and self._proj_buf.dtype == dtype + and self._proj_buf.shape[0] >= bsz + ): + return + self._proj_buf = torch.zeros(bsz, max_seq, self._cp_hidden, dtype=dtype, device=device) + + def _setup_compile(self) -> None: + """Lazily set up torch.compile with optional CUDA graph capture.""" + if self._compiled_model_fwd is not None: + return + + # Cache model parameter dtype so forward() doesn't need to query it + # on every call. Also ensures warmup buffers match model precision + # even when upstream modules produce a different dtype (#2385). + self._model_dtype = next(self.model.parameters()).dtype + self._lm_heads_list = list(self.lm_head) + self._codec_embeds_list = list(self.model.codec_embedding) + + if not current_omni_platform.supports_torch_inductor(): + logger.warning_once("code_predictor: torch.compile disabled") + self._compiled_model_fwd = self.model.forward + return + + # torch.compile fuses RMSNorm/RoPE in ways that lose float32 + # precision, compounding across AR steps. Use epilogue_fusion=False + # to disable the problematic fusions while still getting kernel + # fusion benefits for the linear layers and SDPA. + self._compiled_model_fwd = torch.compile( + self.model.forward, + dynamic=False, + options={"epilogue_fusion": False}, + ) + self._warmup_buckets() + + if self._wrapper_config.use_cuda_graphs: + self._capture_cuda_graphs() + logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs") + else: + logger.info("code_predictor: torch.compile (dynamic=False, no epilogue fusion)") + + def _padded_bsz(self, bsz: int) -> int: + """Round batch size up to nearest power-of-2 bucket.""" + for bucket in self._bucket_sizes: + if bsz <= bucket: + return bucket + return bsz + + def _warmup_buckets(self) -> None: + """Warmup power-of-2 batch-size buckets to front-load Inductor compilation.""" + max_bsz = self._vllm_config.scheduler_config.max_num_seqs + bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz] + if max_bsz not in bucket_sizes: + bucket_sizes.append(max_bsz) + self._bucket_sizes = sorted(bucket_sizes) + + max_seq = self._num_groups + 1 + device = next(self.model.parameters()).device + + # Ensure proj_buf matches model parameter dtype to avoid dtype + # mismatch during warmup compilation (see #2385). + self._ensure_buffers(device, self._model_dtype, max(self._bucket_sizes)) + proj_buf = self._proj_buf + + for bsz in self._bucket_sizes: + pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous() + self._bucket_pos_ids[bsz] = pos_ids + for _ in range(3): + self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) + logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes) + + def _capture_cuda_graphs(self) -> None: + """Capture a CUDA graph per bucket using vLLM's global graph pool.""" + from vllm.platforms import current_platform + + pool = current_platform.get_global_graph_pool() + max_seq = self._num_groups + 1 + proj_buf = self._proj_buf + + for bsz in self._bucket_sizes: + static_input = proj_buf[:bsz, :max_seq, :] + pos_ids = self._bucket_pos_ids[bsz] + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool=pool): + static_output = self._compiled_model_fwd(static_input, pos_ids) + + self._cuda_graphs[bsz] = (g, static_output) + + logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes) + + # ------------------------------------------------------------------ + # Forward -- re-prefill + inline sampling + # ------------------------------------------------------------------ + + @torch.inference_mode() + def forward( + self, + layer0_code: torch.Tensor, + layer0_embed: torch.Tensor, + last_talker_hidden: torch.Tensor, + do_sample: bool = True, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Predict residual codebooks 1..G-1 autoregressively via re-prefill.""" + bsz = int(layer0_code.shape[0]) + num_groups = self._num_groups + device = layer0_code.device + + # _setup_compile caches _model_dtype on first call; use it for buffers + # so they always match model weight precision (#2385). + self._setup_compile() + dtype = self._model_dtype + + padded_bsz = self._padded_bsz(bsz) + self._ensure_buffers(device, dtype, padded_bsz) + + proj_buf = self._proj_buf + max_seq = num_groups + 1 + projection = self.small_to_mtp_projection + model_fwd = self._compiled_model_fwd + lm_heads = self._lm_heads_list + codec_embeds = self._codec_embeds_list + + # Zero the padded region of the buffer + proj_buf[:padded_bsz].zero_() + + # Fill buffer positions 0 (talker hidden) & 1 (layer0 embed) + proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) + proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) + + # Get pre-computed pos_ids for this bucket + full_pos_ids = self._bucket_pos_ids.get(padded_bsz) + if full_pos_ids is None: + full_pos_ids = ( + torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1).contiguous() + ) + + # Use captured CUDA graph if available, otherwise call compiled fn. + cuda_graph_entry = self._cuda_graphs.get(padded_bsz) + + # Prepare sampling parameters + stored_mode = self._wrapper_config.sampling_mode == "stored" + if stored_mode: + s_top_k = self._top_k + s_top_p = self._top_p + else: + use_sampling = do_sample and temperature > 0 + inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 + if use_sampling and top_p != 1.0: + raise NotImplementedError( + "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0." + ) + + # Output codes -- shape depends on return mode + if self._wrapper_config.return_proj_buf: + all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device) + all_codes[:, 0] = layer0_code.reshape(bsz, -1)[:, :1] + else: + all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device) + all_codes[:, 0] = layer0_code.reshape(bsz) + + # Autoregressive loop: predict layers 1..G-1 + for step in range(1, num_groups): + # Run transformer (CUDA graph replay or compiled forward) + if cuda_graph_entry is not None: + cuda_graph_entry[0].replay() + hidden_out = cuda_graph_entry[1] + else: + hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids) + + logits = lm_heads[step - 1](hidden_out[:bsz, step, :]) + + # Sample next code + if stored_mode: + # "stored" mode: top-k -> top-p -> softmax -> multinomial + if s_top_k > 0: + topk_vals, _ = logits.topk(s_top_k, dim=-1) + logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf")) + if s_top_p < 1.0: + sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + remove_mask = (cumulative_probs - sorted_probs) >= s_top_p + sorted_logits[remove_mask] = float("-inf") + logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) + probs = F.softmax(logits, dim=-1) + code = torch.multinomial(probs, num_samples=1) + else: + # "per_call" mode: temperature-scaled + top-k + if use_sampling: + scaled = logits * inv_temperature + if top_k > 0: + topk_vals, _ = scaled.topk(top_k, dim=-1) + scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf")) + probs = F.softmax(scaled, dim=-1) + code = torch.multinomial(probs, num_samples=1) + else: + code = logits.argmax(dim=-1, keepdim=True) + + # Store code + if self._wrapper_config.return_proj_buf: + all_codes[:, step] = code + else: + all_codes[:, step] = code.reshape(bsz) + + # Embed predicted code -> project -> next buffer position + if step < num_groups - 1 or self._wrapper_config.return_proj_buf: + new_embed = codec_embeds[step - 1](code) + proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1) + + if self._wrapper_config.return_proj_buf: + return all_codes, proj_buf[:bsz].clone() + return all_codes + + # ------------------------------------------------------------------ + # Weight loading + # ------------------------------------------------------------------ + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights directly (no fused projection remapping needed).""" + loaded: set[str] = set() + model_weights: list[tuple[str, torch.Tensor]] = [] + other_weights: list[tuple[str, torch.Tensor]] = [] + + for name, w in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.startswith("model."): + model_weights.append((name[len("model.") :], w)) + else: + other_weights.append((name, w)) + + loaded_model = self.model.load_weights(model_weights) + loaded |= {f"model.{n}" for n in loaded_model} + + params = dict(self.named_parameters(remove_duplicate=False)) + for name, w in other_weights: + param = params.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, w) + loaded.add(name) + + return loaded diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py index 2ceaafdb670..819e22e181e 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -1,510 +1,28 @@ -"""Qwen3-Omni Code Predictor -- optimized re-prefill, no KV cache. +"""Qwen3-Omni Code Predictor -- thin wrapper over CodePredictorWrapper.""" -* SDPA attention (F.scaled_dot_product_attention) with native GQA support -* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers) -* Per-call embedding buffer to avoid cross-request aliasing -* Pre-allocated position_ids (read-only, safe to persist) -* torch.compile (epilogue_fusion=False) on inner transformer by default -* Inline sampling (top-k + top-p) -- no custom op overhead -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader - -from vllm_omni.platforms import current_omni_platform - -logger = init_logger(__name__) - - -# =================================================================== -# HF-numerics-compatible layers for code predictor -# =================================================================== -# -# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32, -# rotate_half RoPE) to produce outputs numerically identical to the -# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel, -# get_rope) introduce small precision differences that compound across -# the autoregressive steps of the code predictor, causing severe -# audio quality degradation. -# -# See: https://github.com/vllm-project/vllm-omni/issues/2274 - - -class _RMSNorm(nn.Module): - """RMSNorm matching HuggingFace's implementation exactly. - - Computes variance in float32 to avoid bfloat16 precision loss. - """ - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -class _RotaryEmbedding(nn.Module): - """RoPE matching HuggingFace's implementation exactly. - - Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False). - """ - - def __init__(self, config) -> None: - super().__init__() - head_dim = getattr( - config, - "head_dim", - config.hidden_size // config.num_attention_heads, - ) - rope_theta = getattr(config, "rope_theta", 10000.0) - inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # position_ids: [batch, seq_len] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 (matching HF) - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Qwen3OmniCodePredictorAttention(nn.Module): - """Multi-head self-attention for code predictor. - - Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm. - No KV cache -- the code predictor always re-prefills the full (short) - sequence each AR step. - - Input : [B, seq_len, hidden_size] - Output: [B, seq_len, hidden_size] - """ - - def __init__( - self, - config, - prefix: str = "", - ): - super().__init__() - cp_cfg = config.code_predictor_config - self.num_heads = cp_cfg.num_attention_heads - self.num_kv_heads = cp_cfg.num_key_value_heads - self.head_dim = getattr( - cp_cfg, - "head_dim", - cp_cfg.hidden_size // cp_cfg.num_attention_heads, - ) - self.hidden_size = cp_cfg.hidden_size - self.scaling = self.head_dim**-0.5 - self._use_gqa = self.num_kv_heads != self.num_heads - - # Separate q/k/v projections matching HF (no fused packing) - self.q_proj = nn.Linear( - self.hidden_size, - self.num_heads * self.head_dim, - bias=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_kv_heads * self.head_dim, - bias=False, - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_kv_heads * self.head_dim, - bias=False, - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, - self.hidden_size, - bias=False, - ) - self.q_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) - self.k_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - bsz, seq_len, _ = hidden_states.shape - hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim) - hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim) - - q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2) - k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2) - v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2) - - cos, sin = position_embeddings - # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads - cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim] - sin = sin.unsqueeze(1) - q = (q * cos) + (_rotate_half(q) * sin) - k = (k * cos) + (_rotate_half(k) * sin) - - attn_out = F.scaled_dot_product_attention( - q, - k, - v, - scale=self.scaling, - is_causal=True, - enable_gqa=self._use_gqa, - ) - - attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1) - output = self.o_proj(attn_out) - return output - - -# =================================================================== -# MLP -# =================================================================== - - -class Qwen3OmniCodePredictorMLP(nn.Module): - """SiLU-gated MLP for code predictor, matching HF's implementation.""" - - def __init__( - self, - config, - prefix: str = "", - ): - super().__init__() - hidden_size = config.code_predictor_config.hidden_size - intermediate_size = config.code_predictor_config.intermediate_size - - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) - - -# =================================================================== -# Decoder Layer -# =================================================================== - - -class Qwen3OmniCodePredictorDecoderLayer(nn.Module): - """Transformer decoder layer (SDPA, no KV cache).""" - - def __init__( - self, - config, - prefix: str = "", - ) -> None: - super().__init__() - self.self_attn = Qwen3OmniCodePredictorAttention( - config, - prefix=f"{prefix}.self_attn", - ) - self.mlp = Qwen3OmniCodePredictorMLP( - config, - prefix=f"{prefix}.mlp", - ) - cp_cfg = config.code_predictor_config - self.input_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) - self.post_attention_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps) - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states, position_embeddings) - hidden_states = residual + hidden_states +from vllm_omni.model_executor.models.common.qwen3_code_predictor import ( + CodePredictorWrapper, + CodePredictorWrapperConfig, +) - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states +class Qwen3OmniMoeTalkerCodePredictor(CodePredictorWrapper): + """Qwen3-Omni code predictor (no CUDA graphs, VocabParallelEmbedding).""" -# =================================================================== -# Base Transformer Model (re-prefill, no KV cache) -# =================================================================== - - -class Qwen3OmniCodePredictorBaseModel(nn.Module): - """Inner transformer for code predictor. - - Signature: ``forward(inputs_embeds, position_ids) -> hidden_states`` - -- plain Tensor in, plain Tensor out (no namedtuple). - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config.code_predictor_config - self.config = config - - self.codec_embedding = nn.ModuleList( - [VocabParallelEmbedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)] - ) - - self.layers = nn.ModuleList( - [ - Qwen3OmniCodePredictorDecoderLayer( - vllm_config.model_config.hf_config, - prefix=f"{prefix}.layers.{idx}", - ) - for idx in range(config.num_hidden_layers) - ] - ) - self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = _RotaryEmbedding(config) - - def forward( - self, - inputs_embeds: torch.Tensor, - position_ids: torch.Tensor, - ) -> torch.Tensor: - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - for layer in self.layers: - hidden_states = layer(hidden_states, position_embeddings) - hidden_states = self.norm(hidden_states) - return hidden_states - - -# =================================================================== -# Code Predictor Wrapper (optimized re-prefill, persistent buffers) -# =================================================================== - - -class Qwen3OmniMoeTalkerCodePredictor(nn.Module): - """Optimized code predictor -- re-prefill approach, no KV cache. - - Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1) - through the transformer. The extra O(T^2) FLOPs are negligible for - short sequences, and this avoids all KV-cache management overhead. - - Optimizations: - 1. Per-call embedding buffer -- avoids cross-request aliasing. - 2. Pre-allocated position_ids -- no torch.arange per step. - 3. Cached module references -- bypass ModuleList indexing. - 4. torch.compile on inner transformer. - 5. Inline sampling (top-k + top-p) -- no custom op overhead. - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - self.config = config - self.quant_config = vllm_config.quant_config - self.prefix = prefix - - self.num_code_groups = config.code_predictor_config.num_code_groups - self._hidden_size = config.code_predictor_config.hidden_size - - self.model = Qwen3OmniCodePredictorBaseModel( + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + cp_config = vllm_config.model_config.hf_config.code_predictor_config + super().__init__( vllm_config=vllm_config, + cp_config=cp_config, + wrapper_config=CodePredictorWrapperConfig( + use_cuda_graphs=False, + use_parallel_embedding=True, + use_projection=False, + return_proj_buf=True, + sampling_mode="stored", + ), + talker_hidden_size=cp_config.hidden_size, prefix=prefix, ) - - # One lm_head per residual layer (layers 1 .. G-1) - self.lm_head = nn.ModuleList( - [ - nn.Linear( - config.code_predictor_config.hidden_size, - config.code_predictor_config.vocab_size, - bias=False, - ) - for _ in range(self.num_code_groups - 1) - ] - ) - - self.set_sampling_params() - - # Lazily initialised position ids (read-only, safe to persist) - self._pos_ids: torch.Tensor | None = None - - # Cached plain-list refs (set once) - self._lm_heads: list | None = None - self._codec_embeds: list | None = None - - # Model forward (optionally compiled) - self._model_fwd: object | None = None - - def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8): - """Configure sampling parameters to maintain consistency with previous implementation.""" - self._top_k = top_k - self._top_p = top_p - logger.debug(f"Sampling parameters updated: top_k={top_k}, top_p={top_p}s") - - # ------------------------------------------------------------------ - # Lazy-init helpers - # ------------------------------------------------------------------ - - def _ensure_pos_ids(self, device: torch.device) -> None: - if self._pos_ids is not None and self._pos_ids.device == device: - return - max_seq = self.num_code_groups + 1 - # [1, max_seq] for HF-style RoPE (will be expanded to [bsz, seq_len] at use) - self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device).unsqueeze(0) - - def _ensure_cached_refs(self) -> None: - if self._lm_heads is not None: - return - self._lm_heads = list(self.lm_head) - self._codec_embeds = list(self.model.codec_embedding) - - def _ensure_model_fwd(self) -> None: - if self._model_fwd is not None: - return - if current_omni_platform.supports_torch_inductor(): - # torch.compile fuses RMSNorm/RoPE in ways that lose float32 - # precision, compounding across AR steps. Use epilogue_fusion=False - # to disable the problematic fusions while still getting kernel - # fusion benefits for the linear layers and SDPA. - self._model_fwd = torch.compile( - self.model.forward, - dynamic=True, - options={ - "epilogue_fusion": False, - }, - ) - logger.info("code_predictor: torch.compile enabled (no epilogue fusion)") - else: - self._model_fwd = self.model.forward - logger.info("code_predictor: using eager mode (no torch.compile)") - - # ------------------------------------------------------------------ - # Forward -- re-prefill + inline sampling - # ------------------------------------------------------------------ - - @torch.inference_mode() - def forward( - self, - layer0_code: torch.Tensor, - layer0_embed: torch.Tensor, - last_talker_hidden: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Predict residual codebooks 1..G-1 autoregressively via re-prefill. - - Args: - layer0_code: [bsz, 1] int64 - layer0_embed: [bsz, 1, hidden_size] - last_talker_hidden: [bsz, 1, hidden_size] - - Returns: - all_codes: [bsz, num_code_groups, 1] - proj_buf: [bsz, num_code_groups + 1, hidden_size] - pos 0 = last_talker_hidden (NOT a codec embed) - pos 1 = layer0_embed - pos 2.. = `codec_embedding[i](predicted_code_i)` - """ - bsz = int(layer0_code.shape[0]) - device = layer0_code.device - dtype = last_talker_hidden.dtype - num_groups = self.num_code_groups - - # Lazy init (read-only caches only) - self._ensure_pos_ids(device) - self._ensure_model_fwd() - self._ensure_cached_refs() - - # Allocate proj_buf locally each call to avoid cross-call aliasing - max_seq = num_groups + 1 - proj_buf = torch.zeros(bsz, max_seq, self._hidden_size, dtype=dtype, device=device) - pos_ids = self._pos_ids - model_fwd = self._model_fwd - lm_heads = self._lm_heads - codec_embeds = self._codec_embeds - - # Output codes - all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device) - all_codes[:, 0] = layer0_code - - # Fill buffer positions 0 & 1 - proj_buf[:bsz, 0:1, :] = last_talker_hidden - proj_buf[:bsz, 1:2, :] = layer0_embed - - # Autoregressive loop: predict layers 1..G-1 - for step in range(1, num_groups): - seq_len = step + 1 - projected = proj_buf[:bsz, :seq_len, :] - # position_ids: [batch, seq_len] for HF-style RoPE - step_pos_ids = pos_ids[:, :seq_len].expand(bsz, -1) - - hidden_out = model_fwd(projected, step_pos_ids) - - # Inline sampling: top-k -> top-p -> softmax -> multinomial - logits = lm_heads[step - 1](hidden_out[:, -1, :]) # [bsz, vocab] - if self._top_k > 0: - topk_vals, _ = logits.topk(self._top_k, dim=-1) - logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf")) - if self._top_p < 1.0: - sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) - cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1) - # Remove tokens with cumulative probability above top_p - remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= self._top_p - sorted_logits[remove_mask] = float("-inf") - logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) - probs = F.softmax(logits, dim=-1) - code = torch.multinomial(probs, num_samples=1) # [bsz, 1] - - all_codes[:, step] = code - - # Embed predicted code -> next buffer position - new_embed = codec_embeds[step - 1](code) # [batch, 1, hidden_size] - proj_buf[:bsz, step + 1 : step + 2, :] = new_embed - - return all_codes, proj_buf[:bsz] - - # ------------------------------------------------------------------ - # Weight loading - # ------------------------------------------------------------------ - - def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]: - """Load weights directly (no fused projection remapping needed). - - Since we use separate nn.Linear for q/k/v/o and gate/up/down, - weight names match the HF checkpoint directly. - """ - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - # Skip rotary embeddings - if "rotary_emb.inv_freq" in name: - continue - - param = params_dict.get(name) - if param is None: - continue - - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - return loaded_params diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 1e84eaebaa5..8d2f0686ae0 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -1,318 +1,27 @@ +"""Qwen3-TTS Code Predictor -- thin wrapper over CodePredictorWrapper.""" + from __future__ import annotations from collections.abc import Iterable import torch -import torch.nn as nn -import torch.nn.functional as F from vllm.config import VllmConfig from vllm.config.vllm import set_current_vllm_config -from vllm.logger import init_logger -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, -) -from vllm_omni.platforms import current_omni_platform +from vllm_omni.model_executor.models.common.qwen3_code_predictor import ( + CodePredictorBaseModel, + CodePredictorWrapper, + CodePredictorWrapperConfig, +) from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig -logger = init_logger(__name__) - - -# =================================================================== -# HF-numerics-compatible layers for code predictor -# =================================================================== -# -# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32, -# rotate_half RoPE) to produce outputs numerically identical to the -# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel, -# get_rope) introduce small precision differences that compound across -# the 15 autoregressive steps of the code predictor, causing severe -# audio quality degradation (UTMOS ~4.26 → ~2.66). -# -# See: https://github.com/vllm-project/vllm-omni/issues/2274 - - -class _RMSNorm(nn.Module): - """RMSNorm matching HuggingFace's Qwen3TTSRMSNorm exactly. - - Computes variance in float32 to avoid bfloat16 precision loss. - """ - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -class _RotaryEmbedding(nn.Module): - """RoPE matching HuggingFace's Qwen3TTSRotaryEmbedding exactly. - - Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False). - """ - - def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig) -> None: - super().__init__() - head_dim = getattr( - config, - "head_dim", - config.hidden_size // config.num_attention_heads, - ) - # Standard default RoPE - rope_theta = getattr(config, "rope_theta", 10000.0) - inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # position_ids: [batch, seq_len] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 (matching HF) - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class _CodePredictorAttention(nn.Module): - """Standalone multi-head attention for code predictor. - - Uses F.scaled_dot_product_attention with HF-compatible RoPE and RMSNorm. - Input: [B, seq_len, hidden_size], output: [B, seq_len, hidden_size]. - """ - - def __init__( - self, - config: Qwen3TTSTalkerCodePredictorConfig, - *, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = getattr( - config, - "head_dim", - config.hidden_size // config.num_attention_heads, - ) - self.scaling = self.head_dim**-0.5 - self._use_gqa = self.num_kv_heads != self.num_heads - - # Separate q/k/v projections matching HF (no fused packing) - self.q_proj = nn.Linear( - self.hidden_size, - self.num_heads * self.head_dim, - bias=getattr(config, "attention_bias", False), - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_kv_heads * self.head_dim, - bias=getattr(config, "attention_bias", False), - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_kv_heads * self.head_dim, - bias=getattr(config, "attention_bias", False), - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, - self.hidden_size, - bias=False, - ) - self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - bsz, seq_len, _ = hidden_states.shape - hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim) - hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim) - - q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2) - k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2) - v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2) - - cos, sin = position_embeddings - # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads - cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim] - sin = sin.unsqueeze(1) - q = (q * cos) + (_rotate_half(q) * sin) - k = (k * cos) + (_rotate_half(k) * sin) - - attn_out = F.scaled_dot_product_attention( - q, - k, - v, - scale=self.scaling, - is_causal=True, - enable_gqa=self._use_gqa, - ) - - attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1) - output = self.o_proj(attn_out) - return output - - -class _CodePredictorMLP(nn.Module): - """SiLU-gated MLP for code predictor, matching HF's Qwen3TTSTalkerTextMLP.""" - - def __init__( - self, - config: Qwen3TTSTalkerCodePredictorConfig, - *, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class _CodePredictorDecoderLayer(nn.Module): - """Transformer decoder layer for code predictor (SDPA, no KV cache).""" - - def __init__( - self, - config: Qwen3TTSTalkerCodePredictorConfig, - *, - prefix: str = "", - ) -> None: - super().__init__() - self.self_attn = _CodePredictorAttention(config, prefix=f"{prefix}.self_attn") - self.mlp = _CodePredictorMLP(config, prefix=f"{prefix}.mlp") - self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states, position_embeddings) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -# =================================================================== -# Code Predictor Transformer Model -# =================================================================== - - -class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module): - """Transformer model for the code predictor (re-prefill, no KV cache).""" - - def __init__( - self, - config: Qwen3TTSTalkerCodePredictorConfig, - *, - talker_hidden_size: int | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - - self.layers = nn.ModuleList( - [_CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{i}") for i in range(config.num_hidden_layers)] - ) - self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = _RotaryEmbedding(config) - - # Codec embeddings: one per residual group. Stored in talker hidden dim - # (some checkpoints use talker_hidden_size != code_predictor hidden_size). - emb_dim = int(talker_hidden_size) if talker_hidden_size is not None else int(config.hidden_size) - self.codec_embedding = nn.ModuleList( - [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] - ) - - def get_input_embeddings(self) -> nn.ModuleList: - return self.codec_embedding - - def forward( - self, - inputs_embeds: torch.Tensor, - position_ids: torch.Tensor, - ) -> torch.Tensor: - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - for layer in self.layers: - hidden_states = layer(hidden_states, position_embeddings) - hidden_states = self.norm(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - param = params_dict.get(name) - if param is None: - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -# =================================================================== -# Code Predictor Wrapper (optimized re-prefill + torch.compile) -# =================================================================== - - -class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(nn.Module): - """vLLM-native code_predictor for the AR talker (residual codebooks). +# Backward-compat alias used by tests +Qwen3TTSTalkerCodePredictorModelVLLM = CodePredictorBaseModel - Re-prefill approach: each AR step forwards the full growing sequence - through the 5-layer transformer. No KV cache needed. This trades - ~O(T^2) extra attention FLOPs (negligible for T=16, 5 layers) for - zero KV cache management overhead and a simpler execution model. - Uses HF-compatible layers (plain nn.Linear, float32 RMSNorm, rotate_half - RoPE) to ensure numerical fidelity with the reference implementation. - Precision matters here because small errors compound across 15 AR steps. - - Optimizations preserved: - 1. torch.compile on model forward -- fuses small kernel launches. - 2. Pre-allocated embedding buffer [B, max_seq, H] -- no torch.cat per step. - 3. Projection caching -- each token projected once and cached. - 4. Pre-allocated position_ids -- no torch.arange per step. - 5. Inline sampling -- no custom op / forward_context overhead. - 6. Cached module references -- bypass nn.Module.__call__ overhead. - 7. CUDA graphs per batch-size bucket. - """ +class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(CodePredictorWrapper): + """Qwen3-TTS code predictor (CUDA graphs, per-call sampling, projection).""" def __init__( self, @@ -322,250 +31,24 @@ def __init__( talker_config: Qwen3TTSTalkerConfig, prefix: str = "code_predictor", ) -> None: - super().__init__() - self._vllm_config = vllm_config - self.config = config - self.talker_config = talker_config - - self.model = Qwen3TTSTalkerCodePredictorModelVLLM( - config, + super().__init__( + vllm_config=vllm_config, + cp_config=config, + wrapper_config=CodePredictorWrapperConfig( + use_cuda_graphs=True, + use_parallel_embedding=False, + use_projection=(config.hidden_size != talker_config.hidden_size), + return_proj_buf=False, + sampling_mode="per_call", + ), talker_hidden_size=int(talker_config.hidden_size), - prefix=f"{prefix}.model", + prefix=prefix, ) - - self.lm_head = nn.ModuleList( - [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] - ) - - if config.hidden_size != talker_config.hidden_size: - self.small_to_mtp_projection = nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True) - else: - self.small_to_mtp_projection = nn.Identity() - - self._num_groups = int(config.num_code_groups) - self._talker_hidden = int(talker_config.hidden_size) - self._cp_hidden = int(config.hidden_size) - - # Pre-allocated buffers (lazily initialized on first forward). - self._proj_buf: torch.Tensor | None = None - self._model_dtype: torch.dtype | None = None - - # torch.compile + warmup state (lazily initialized in _setup_compile). - self._compiled_model_fwd = None - self._bucket_sizes: list[int] = [] - self._bucket_pos_ids: dict[int, torch.Tensor] = {} - self._lm_heads_list: list[nn.Module] | None = None - self._codec_embeds_list: list[nn.Module] | None = None - self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {} - - def get_input_embeddings(self) -> nn.ModuleList: - return self.model.get_input_embeddings() + # Store talker_config for backward compat (accessed by some callers) + self.talker_config = talker_config + self._vllm_config = vllm_config def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights with vllm config context (required for VocabParallelEmbedding).""" with set_current_vllm_config(self._vllm_config): - loaded: set[str] = set() - model_weights: list[tuple[str, torch.Tensor]] = [] - other_weights: list[tuple[str, torch.Tensor]] = [] - for name, w in weights: - if name.startswith("model."): - model_weights.append((name[len("model.") :], w)) - else: - other_weights.append((name, w)) - - loaded_model = self.model.load_weights(model_weights) - loaded |= {f"model.{n}" for n in loaded_model} - - params = dict(self.named_parameters(remove_duplicate=False)) - for name, w in other_weights: - if name not in params: - continue - default_weight_loader(params[name], w) - loaded.add(name) - - return loaded - - # ------------------------------------------------------------------ - # Pre-allocated buffer management - # ------------------------------------------------------------------ - - def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None: - max_seq = self._num_groups + 1 - if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype: - return - max_bsz = self._vllm_config.scheduler_config.max_num_seqs - self._proj_buf = torch.zeros( - max_bsz, - max_seq, - self._cp_hidden, - dtype=dtype, - device=device, - ) - - def _setup_compile(self) -> None: - """Lazily set up torch.compile with manual CUDA graph capture.""" - if self._compiled_model_fwd is not None: - return - # Cache model parameter dtype so forward() doesn't need to query it - # on every call. Also ensures warmup buffers match model precision - # even when upstream modules produce a different dtype (#2385). - self._model_dtype = next(self.model.parameters()).dtype - self._lm_heads_list = list(self.lm_head) - self._codec_embeds_list = list(self.model.codec_embedding) - if not current_omni_platform.supports_torch_inductor(): - logger.warning_once("code_predictor: torch.compile disabled") - self._compiled_model_fwd = self.model.forward - return - - # torch.compile fuses RMSNorm/RoPE in ways that lose float32 - # precision, compounding across 15 AR steps. Use torch.compile - # with options that disable the problematic fusions while still - # getting kernel fusion benefits for the linear layers and SDPA. - self._compiled_model_fwd = torch.compile( - self.model.forward, - dynamic=False, - options={ - "epilogue_fusion": False, - }, - ) - self._warmup_buckets() - self._capture_cuda_graphs() - logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs") - - def _padded_bsz(self, bsz: int) -> int: - for bucket in self._bucket_sizes: - if bsz <= bucket: - return bucket - return bsz - - def _warmup_buckets(self) -> None: - """Warmup power-of-2 batch-size buckets to front-load Inductor compilation.""" - max_bsz = self._vllm_config.scheduler_config.max_num_seqs - bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz] - if max_bsz not in bucket_sizes: - bucket_sizes.append(max_bsz) - self._bucket_sizes = sorted(bucket_sizes) - - max_seq = self._num_groups + 1 - device = next(self.model.parameters()).device - - # Ensure proj_buf matches model parameter dtype to avoid dtype - # mismatch during warmup compilation (see #2385). - self._ensure_buffers(device, self._model_dtype) - proj_buf = self._proj_buf - for bsz in self._bucket_sizes: - # position_ids: [batch, seq_len] for HF-style RoPE - pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1) - self._bucket_pos_ids[bsz] = pos_ids - for _ in range(3): - self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) - logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes) - - def _capture_cuda_graphs(self) -> None: - """Capture a CUDA graph per bucket using vLLM's global graph pool.""" - from vllm.platforms import current_platform - - pool = current_platform.get_global_graph_pool() - - max_seq = self._num_groups + 1 - proj_buf = self._proj_buf - - for bsz in self._bucket_sizes: - static_input = proj_buf[:bsz, :max_seq, :] - pos_ids = self._bucket_pos_ids[bsz] - - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, pool=pool): - static_output = self._compiled_model_fwd(static_input, pos_ids) - - self._cuda_graphs[bsz] = (g, static_output) - - logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes) - - # ------------------------------------------------------------------ - # Optimized forward: re-prefill + torch.compile + projection cache - # ------------------------------------------------------------------ - - @torch.inference_mode() - def forward( - self, - layer0_code: torch.Tensor, - layer0_embed: torch.Tensor, - last_talker_hidden: torch.Tensor, - do_sample: bool = True, - temperature: float = 0.9, - top_k: int = 50, - top_p: float = 1.0, - ) -> torch.Tensor: - """Predict residual codebooks 1..Q-1 autoregressively via re-prefill. - - torch.compile fuses the ~60 small kernel launches per step into fewer - fused kernels, reducing kernel launch overhead by ~75%. - - Projection caching: each token is projected once via small_to_mtp_projection - and cached in _proj_buf, avoiding redundant re-projection of past tokens. - """ - bsz = int(layer0_code.shape[0]) - num_groups = self._num_groups - device = layer0_code.device - - all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device) - all_codes[:, 0] = layer0_code.reshape(bsz) - - # _setup_compile caches _model_dtype on first call; use it for buffers - # so they always match model weight precision (#2385). - self._setup_compile() - dtype = self._model_dtype - self._ensure_buffers(device, dtype) - - proj_buf = self._proj_buf - max_seq = self._num_groups + 1 - - projection = self.small_to_mtp_projection - model_fwd = self._compiled_model_fwd - lm_heads = self._lm_heads_list - codec_embeds = self._codec_embeds_list - - use_sampling = do_sample and temperature > 0 - inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 - if use_sampling and top_p != 1.0: - raise NotImplementedError( - "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0." - ) - - padded_bsz = self._padded_bsz(bsz) - proj_buf[:padded_bsz].zero_() - - proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) - proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) - full_pos_ids = self._bucket_pos_ids.get(padded_bsz) - if full_pos_ids is None: - full_pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1) - - # Use captured CUDA graph if available, otherwise call compiled fn. - cuda_graph_entry = self._cuda_graphs.get(padded_bsz) - - for step in range(1, num_groups): - if cuda_graph_entry is not None: - cuda_graph_entry[0].replay() - hidden_out = cuda_graph_entry[1] - else: - hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids) - logits = lm_heads[step - 1](hidden_out[:bsz, step, :]) - - if use_sampling: - scaled = logits * inv_temperature - if top_k > 0: - topk_vals, _ = scaled.topk(top_k, dim=-1) - scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf")) - probs = F.softmax(scaled, dim=-1) - next_ids = torch.multinomial(probs, num_samples=1) - else: - next_ids = logits.argmax(dim=-1, keepdim=True) - - all_codes[:, step] = next_ids.reshape(bsz) - - if step < num_groups - 1: - new_embed = codec_embeds[step - 1](next_ids) - proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1) - - return all_codes + return super().load_weights(weights)