From 68c1ed09669181e25c2f8f9159358266fe22360a Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Thu, 23 Apr 2026 19:22:44 +0000 Subject: [PATCH] [Feat] Add CUDA graph support for Qwen3-Omni Code2Wav decoder Signed-off-by: JuanPZuluaga --- .../offline_inference/qwen3_omni/end2end.py | 7 + .../qwen3_tts/test_cuda_graph_decoder.py | 4 +- .../models/common/snake_activation.py | 135 ++++++++++++++++++ .../models/qwen3_omni/qwen3_omni.py | 27 ++++ .../models/qwen3_omni/qwen3_omni_code2wav.py | 102 ++++++++++--- .../modeling_qwen3_tts_tokenizer_v2.py | 127 +--------------- 6 files changed, 256 insertions(+), 146 deletions(-) create mode 100644 vllm_omni/model_executor/models/common/snake_activation.py diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 04aa7914db1..740e6ae224c 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -389,6 +389,13 @@ def main(args): output_wav = os.path.join(output_dir, f"output_{request_id}.wav") # Convert to numpy array and ensure correct format + # In async_chunk mode, audio may arrive as a list of chunks + if isinstance(audio_tensor, list): + import torch + + audio_tensor = torch.cat( + [(t if isinstance(t, torch.Tensor) else torch.tensor(t)).flatten() for t in audio_tensor] + ) audio_numpy = audio_tensor.float().detach().cpu().numpy() # Ensure audio is 1D (flatten if needed) diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 86af757809d..77568a665e9 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -314,9 +314,7 @@ def test_compute_capture_sizes(kwargs, expected_in, not_expected): ) def test_snakebeta_triton_vs_eager(batch, channels, seq_len): """Fused Triton SnakeBeta kernel must match eager PyTorch output.""" - from vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( - SnakeBeta, - ) + from vllm_omni.model_executor.models.common.snake_activation import SnakeBeta if not SnakeBeta._init_triton(): pytest.skip("Triton not available") diff --git a/vllm_omni/model_executor/models/common/snake_activation.py b/vllm_omni/model_executor/models/common/snake_activation.py new file mode 100644 index 00000000000..0be93d74a0b --- /dev/null +++ b/vllm_omni/model_executor/models/common/snake_activation.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +"""Shared SnakeBeta activation for speech decoders (Qwen3-TTS, Qwen3-Omni Code2Wav).""" + +import torch +from torch import nn +from torch.nn import Parameter +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper + by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + _triton_kernel = None # None = untried, False = unavailable, callable = ready + _TRITON_MAX_BLOCK_T = 4096 # upper bound for time-axis tile size + + @staticmethod + def _init_triton(): + """Load and JIT-compile the fused Triton kernel (once).""" + if SnakeBeta._triton_kernel is not None: + return SnakeBeta._triton_kernel is not False + try: + import triton + import triton.language as tl + except ImportError: + SnakeBeta._triton_kernel = False + return False + + @triton.jit + def _kernel( # noqa: N803 + x_ptr, + exp_alpha_ptr, + inv_beta_ptr, + out_ptr, + stride_b, + stride_c, + t_len, + block_t: tl.constexpr, + ): + """Fused SnakeBeta using precomputed exp(α) and 1/(exp(β)+ε).""" + bid = tl.program_id(0) + cid = tl.program_id(1) + t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) + mask = t_off < t_len + + x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0) + ea = tl.load(exp_alpha_ptr + cid) + ib = tl.load(inv_beta_ptr + cid) + sin_val = tl.sin(x * ea) + result = x + ib * sin_val * sin_val + + tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask) + + SnakeBeta._triton_kernel = _kernel + return True + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + # Precomputed buffers (populated by precompute_exp_cache) + self.register_buffer("_exp_alpha", None, persistent=False) + self.register_buffer("_inv_beta", None, persistent=False) + + def precompute_exp_cache(self): + """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers.""" + with torch.no_grad(): + self._exp_alpha = torch.exp(self.alpha).contiguous() + self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous() + + @property + def _cached(self): + return self._exp_alpha is not None + + def forward(self, hidden_states): + """SnakeBeta := x + 1/b * sin^2(x*a)""" + if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): + try: + return self._triton_forward(hidden_states) + except Exception: + logger.warning("Triton SnakeBeta failed, falling back to eager", exc_info=True) + SnakeBeta._triton_kernel = False + return self._eager_forward(hidden_states) + + def _eager_forward(self, hidden_states): + if self._cached: + exp_alpha = self._exp_alpha.unsqueeze(0).unsqueeze(-1) + inv_beta = self._inv_beta.unsqueeze(0).unsqueeze(-1) + else: + exp_alpha = torch.exp(self.alpha).unsqueeze(0).unsqueeze(-1) + inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).unsqueeze(0).unsqueeze(-1) + hidden_states = hidden_states + inv_beta * torch.pow(torch.sin(hidden_states * exp_alpha), 2) + return hidden_states + + def _triton_forward(self, x): + import triton + + if not self._cached: + self.precompute_exp_cache() + + x = x.contiguous() + B, C, T = x.shape + out = torch.empty_like(x) + block_t = min(triton.next_power_of_2(T), self._TRITON_MAX_BLOCK_T) + self._triton_kernel[(B, C, triton.cdiv(T, block_t))]( + x, + self._exp_alpha, + self._inv_beta, + out, + x.stride(0), + x.stride(1), + t_len=T, + block_t=block_t, + ) + return out diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 28b969ff7cd..cf57b504b54 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -1257,6 +1257,18 @@ def sample( # ==================== Weight Loading ==================== + def _get_codec_frame_config(self) -> tuple[int, int]: + """Extract codec_chunk_frames and codec_left_context_frames from stage connector config.""" + model_cfg = getattr(self.vllm_config, "model_config", None) + connector_cfg = getattr(model_cfg, "stage_connector_config", None) + if isinstance(connector_cfg, dict): + extra = connector_cfg.get("extra", {}) + else: + extra = getattr(connector_cfg, "extra", None) or {} + chunk_frames = int(extra.get("codec_chunk_frames", 0) or 0) + left_frames = int(extra.get("codec_left_context_frames", 0) or 0) + return chunk_frames, left_frames + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights for all components of the omni model.""" loaded_weights = set() @@ -1293,6 +1305,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: code2wav_loaded = add_prefix_to_loaded_weights(code2wav_loaded, "code2wav") loaded_weights.update(code2wav_loaded) + # Precompute SnakeBeta caches and enable CUDA graph for Code2Wav decoder + try: + self.code2wav.precompute_snake_caches() + if hasattr(self.code2wav, "enable_cudagraph"): + chunk_frames, left_frames = self._get_codec_frame_config() + self.code2wav.enable_cudagraph( + codec_chunk_frames=chunk_frames, + codec_left_context_frames=left_frames, + ) + except Exception: + logger.warning( + "Failed to enable CUDA Graph for Code2Wav; falling back to eager.", + exc_info=True, + ) + # Log summary logger.info( "Loaded %d weights for Qwen3OmniMoe (stage=%s)", diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py index 41eeec5e7f4..8bcba4340dd 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py @@ -19,7 +19,6 @@ Qwen3OmniMoeCode2WavDecoderBlock, Qwen3OmniMoeCode2WavTransformerModel, Qwen3OmniMoeConvNeXtBlock, - SnakeBeta, ) from vllm.config import VllmConfig # type: ignore from vllm.logger import init_logger # type: ignore @@ -28,6 +27,8 @@ WeightsMapper, ) +from vllm_omni.model_executor.models.common.snake_activation import SnakeBeta + logger = init_logger(__name__) @@ -119,6 +120,61 @@ def __init__( ] self.decoder = nn.ModuleList(decoder) + # CUDA Graph support — reuses CUDAGraphDecoderWrapper from Qwen3-TTS + self._cudagraph_enabled = False + self._cudagraph_wrapper = None + + def precompute_snake_caches(self): + """Precompute exp(alpha) and 1/(exp(beta)+eps) for all SnakeBeta modules.""" + count = 0 + for module in self.modules(): + if isinstance(module, SnakeBeta): + module.precompute_exp_cache() + count += 1 + if count > 0: + logger.info("Precomputed exp caches for %d SnakeBeta activations", count) + + def enable_cudagraph( + self, + device: torch.device | None = None, + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, + ): + """Enable CUDA graph acceleration (same pattern as Qwen3-TTS Code2Wav).""" + from vllm_omni.model_executor.models.qwen3_tts.cuda_graph_decoder_wrapper import ( + CUDAGraphDecoderWrapper, + ) + + if device is None: + device = next(self.parameters()).device + if device.type != "cuda": + logger.warning("Cannot enable CUDA Graph: not on CUDA device (got %s)", device) + return + + wrapper = CUDAGraphDecoderWrapper( + decoder=self, + num_quantizers=self.config.num_quantizers, + enabled=True, + ) + try: + wrapper.warmup( + device, + dtype=torch.long, + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + ) + except Exception: + self._cudagraph_wrapper = None + self._cudagraph_enabled = False + raise + self._cudagraph_wrapper = wrapper + self._cudagraph_enabled = True + logger.info( + "CUDA Graph enabled for Code2Wav: num_quantizers=%d, sizes=%s", + self.config.num_quantizers, + self._cudagraph_wrapper.capture_sizes, + ) + def forward(self, codes: torch.Tensor) -> torch.Tensor: """ Convert num_quantizers-layer RVQ codes to audio waveform. @@ -168,6 +224,8 @@ def chunked_decode( Decode long sequences in chunks to avoid OOM. Uses overlapping chunks with left context to avoid boundary artifacts. + When CUDA graphs are enabled, delegates chunk-level decoding to the + CUDAGraphDecoderWrapper for reduced kernel launch overhead. Args: codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes @@ -180,35 +238,40 @@ def chunked_decode( codes. For ``batch_size == 1``, this is a list containing a single tensor with shape ``[1, waveform_len]``. """ - wavs = [] - start_index = 0 + # Use CUDA graph wrapper for chunk-level decode when available + if self._cudagraph_enabled and self._cudagraph_wrapper is not None: + batch_wav = self._cudagraph_wrapper.chunked_decode_with_cudagraph(codes, chunk_size, left_context_size) + else: + wavs = [] + start_index = 0 + + while start_index < codes.shape[-1]: + end_index = min(start_index + chunk_size, codes.shape[-1]) + context_size = left_context_size if start_index >= left_context_size else start_index - while start_index < codes.shape[-1]: - end_index = min(start_index + chunk_size, codes.shape[-1]) - context_size = left_context_size if start_index >= left_context_size else start_index + # Extract chunk with left context + codes_chunk = codes[..., start_index - context_size : end_index] - # Extract chunk with left context - codes_chunk = codes[..., start_index - context_size : end_index] + # Decode chunk + wav_chunk = self(codes_chunk) - # Decode chunk - wav_chunk = self(codes_chunk) + # Remove context from output (context_size * total_upsample samples) + wavs.append(wav_chunk[..., context_size * self.total_upsample :]) - # Remove context from output (context_size * total_upsample samples) - wavs.append(wav_chunk[..., context_size * self.total_upsample :]) + start_index = end_index - start_index = end_index + batch_wav = torch.cat(wavs, dim=-1) if seq_token_counts is not None: code_seq_lens = [seq_len // self.config.num_quantizers for seq_len in seq_token_counts] else: # Fallback: assume all batch elements share the same sequence length. code_seq_lens = [codes.shape[-1]] * codes.shape[0] - batch_wav = torch.cat(wavs, dim=-1) - wavs = [] + result = [] for idx, code_seq_len in enumerate(code_seq_lens): wav_chunk = batch_wav[idx, :, : code_seq_len * self.total_upsample] - wavs.append(wav_chunk) - return wavs + result.append(wav_chunk) + return result def chunked_decode_streaming( self, @@ -241,7 +304,10 @@ def chunked_decode_streaming( left_context_size = [0] * codes.shape[0] # Decode chunk wavs = [] - batch_wav = self(codes) + if self._cudagraph_enabled and self._cudagraph_wrapper is not None: + batch_wav = self._cudagraph_wrapper.decode(codes) + else: + batch_wav = self(codes) if seq_token_counts is not None: code_seq_lens = [n // self.config.num_quantizers for n in seq_token_counts] else: diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index dfa48df298a..d96505b8db1 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -21,7 +21,6 @@ import numpy as np import torch from torch import nn -from torch.nn import Parameter from torch.nn import functional as F from transformers import MimiConfig, MimiModel from transformers.activations import ACT2FN @@ -40,6 +39,8 @@ from transformers.utils import ModelOutput, auto_docstring, logging from transformers.utils.deprecation import deprecate_kwarg +from vllm_omni.model_executor.models.common.snake_activation import SnakeBeta + from .configuration_qwen3_tts_tokenizer_v2 import ( Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2DecoderConfig, @@ -599,130 +600,6 @@ def forward( ) -class SnakeBeta(nn.Module): - """ - A modified Snake function which uses separate parameters for the magnitude of the periodic components - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - References: - - This activation function is a modified version based on this paper - by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://huggingface.co/papers/2006.08195 - """ - - _triton_kernel = None # None = untried, False = unavailable, callable = ready - _TRITON_MAX_BLOCK_T = 4096 # upper bound for time-axis tile size - - @staticmethod - def _init_triton(): - """Load and JIT-compile the fused Triton kernel (once).""" - if SnakeBeta._triton_kernel is not None: - return SnakeBeta._triton_kernel is not False - try: - import triton - import triton.language as tl - except ImportError: - SnakeBeta._triton_kernel = False - return False - - @triton.jit - def _kernel( # noqa: N803 - x_ptr, - exp_alpha_ptr, - inv_beta_ptr, - out_ptr, - stride_b, - stride_c, - t_len, - block_t: tl.constexpr, - ): - """Fused SnakeBeta using precomputed exp(α) and 1/(exp(β)+ε).""" - bid = tl.program_id(0) - cid = tl.program_id(1) - t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) - mask = t_off < t_len - - x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0) - ea = tl.load(exp_alpha_ptr + cid) - ib = tl.load(inv_beta_ptr + cid) - sin_val = tl.sin(x * ea) - result = x + ib * sin_val * sin_val - - tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask) - - SnakeBeta._triton_kernel = _kernel - return True - - def __init__(self, in_features, alpha=1.0): - super().__init__() - self.in_features = in_features - - self.alpha = Parameter(torch.zeros(in_features) * alpha) - self.beta = Parameter(torch.zeros(in_features) * alpha) - - self.no_div_by_zero = 0.000000001 - - # Precomputed buffers (populated by precompute_exp_cache) - self.register_buffer("_exp_alpha", None, persistent=False) - self.register_buffer("_inv_beta", None, persistent=False) - - def precompute_exp_cache(self): - """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers.""" - with torch.no_grad(): - self._exp_alpha = torch.exp(self.alpha).contiguous() - self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous() - - @property - def _cached(self): - return self._exp_alpha is not None - - def forward(self, hidden_states): - """SnakeBeta := x + 1/b * sin^2(x*a)""" - if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): - try: - return self._triton_forward(hidden_states) - except Exception: - logger.warning("Triton SnakeBeta failed, falling back to eager", exc_info=True) - SnakeBeta._triton_kernel = False - return self._eager_forward(hidden_states) - - def _eager_forward(self, hidden_states): - if self._cached: - exp_alpha = self._exp_alpha.unsqueeze(0).unsqueeze(-1) - inv_beta = self._inv_beta.unsqueeze(0).unsqueeze(-1) - else: - exp_alpha = torch.exp(self.alpha).unsqueeze(0).unsqueeze(-1) - inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).unsqueeze(0).unsqueeze(-1) - hidden_states = hidden_states + inv_beta * torch.pow(torch.sin(hidden_states * exp_alpha), 2) - return hidden_states - - def _triton_forward(self, x): - import triton - - if not self._cached: - self.precompute_exp_cache() - - x = x.contiguous() - B, C, T = x.shape - out = torch.empty_like(x) - block_t = min(triton.next_power_of_2(T), self._TRITON_MAX_BLOCK_T) - self._triton_kernel[(B, C, triton.cdiv(T, block_t))]( - x, - self._exp_alpha, - self._inv_beta, - out, - x.stride(0), - x.stride(1), - t_len=T, - block_t=block_t, - ) - return out - - class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module): def __init__(self, dim: int = 16, dilation: int = 1): super().__init__()