diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 3739687606d..86af757809d 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -270,3 +270,63 @@ def test_deterministic_across_calls(decoder, wrapper): out1 = wrapper.decode(codes) out2 = wrapper.decode(codes) torch.testing.assert_close(out1, out2, atol=0, rtol=0) + + +# ────────────────────────────────────────────────────────────────── +# 6. compute_capture_sizes +# ────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "kwargs,expected_in,not_expected", + [ + ({}, [2, 4, 8, 16, 32, 64, 128, 256, 325], [512]), + ( + {"codec_chunk_frames": 33, "codec_left_context_frames": 25}, + [2, 4, 8, 16, 32, 33, 58, 64, 128, 256, 325], + [512], + ), + ( + {"codec_chunk_frames": 25, "codec_left_context_frames": 25}, + [2, 4, 8, 16, 25, 32, 50, 64, 128, 256, 325], + [512], + ), + ], + ids=["default", "streaming_c33", "streaming_c25"], +) +def test_compute_capture_sizes(kwargs, expected_in, not_expected): + """compute_capture_sizes produces expected sizes capped by max useful size.""" + sizes = CUDAGraphDecoderWrapper.compute_capture_sizes(**kwargs) + for val in expected_in: + assert val in sizes, f"{val} not in {sizes}" + for val in not_expected: + assert val not in sizes, f"{val} should not be in {sizes}" + + +# ────────────────────────────────────────────────────────────────── +# 7. SnakeBeta Triton kernel vs eager equivalence +# ────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "batch,channels,seq_len", + [(2, 64, 1000), (1, 32, 1), (1, 32, 7), (1, 32, 128), (1, 32, 1024), (1, 32, 4096)], +) +def test_snakebeta_triton_vs_eager(batch, channels, seq_len): + """Fused Triton SnakeBeta kernel must match eager PyTorch output.""" + from vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( + SnakeBeta, + ) + + if not SnakeBeta._init_triton(): + pytest.skip("Triton not available") + + torch.manual_seed(42) + snake = SnakeBeta(in_features=channels).to(DEVICE).eval() + x = torch.randn(batch, channels, seq_len, device=DEVICE) + + with torch.no_grad(): + eager_out = snake._eager_forward(x) + triton_out = snake._triton_forward(x) + + torch.testing.assert_close(triton_out, eager_out, atol=1e-5, rtol=1e-5) diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index 5ff259193d2..96f8c799c13 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -29,8 +29,6 @@ class CUDAGraphDecoderWrapper: output = wrapper.decode(codes) # Automatically uses CUDA graph if possible """ - DEFAULT_CAPTURE_SIZES = [2, 4, 8, 16, 25, 32, 50, 100, 150, 200, 250, 300] - def __init__( self, decoder: torch.nn.Module, @@ -39,7 +37,8 @@ def __init__( enabled: bool = True, ): self.decoder = decoder - self.capture_sizes = capture_sizes or self.DEFAULT_CAPTURE_SIZES + self._explicit_sizes = capture_sizes is not None + self.capture_sizes = sorted(capture_sizes) if capture_sizes else [] self.num_quantizers = num_quantizers self.enabled = enabled @@ -50,66 +49,82 @@ def __init__( self._warmed_up = False self._device = None + @staticmethod + def compute_capture_sizes( + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, + decode_chunk_size: int = 300, + decode_left_context: int = 25, + ) -> list[int]: + """Compute capture sizes from chunking config for high graph hit rate.""" + sizes: set[int] = set() + + # Streaming exact hits + if codec_chunk_frames > 0: + sizes.add(codec_chunk_frames) + if codec_left_context_frames > 0: + sizes.add(codec_chunk_frames + codec_left_context_frames) + + # Non-streaming chunked decode: full chunk + last-chunk buckets + non_stream_max = decode_chunk_size + decode_left_context + sizes.add(non_stream_max) + + # Power-of-2 buckets covering both streaming IC sizes and non-streaming last-chunk sizes + for p2 in [2, 4, 8, 16, 32, 64, 128, 256]: + if p2 <= non_stream_max: + sizes.add(p2) + + return sorted(sizes) + def _get_padded_size(self, actual_size: int) -> int | None: for size in self.capture_sizes: if actual_size <= size: return size return None - def warmup(self, device: torch.device, dtype: torch.dtype = torch.long): - if device.type != "cuda": - logger.info("CUDA Graph warmup skipped: device %s is not CUDA", device) - return - - if not self.enabled: - logger.info("CUDA Graph is disabled, skipping warmup") - return - - if self._warmed_up: - logger.warning("CUDA Graph already warmed up, skipping") + def warmup( + self, + device: torch.device, + dtype: torch.dtype = torch.long, + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, + ): + if device.type != "cuda" or not self.enabled or self._warmed_up: return self._device = device self.decoder.eval() + if not self._explicit_sizes: + self.capture_sizes = self.compute_capture_sizes( + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + ) + logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) # Warmup runs to ensure CUDA memory is allocated for size in self.capture_sizes: - dummy_codes = torch.zeros( - 1, - self.num_quantizers, - size, - dtype=dtype, - device=device, - ) + dummy = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): - _ = self.decoder(dummy_codes) + _ = self.decoder(dummy) torch.cuda.synchronize(device) for size in self.capture_sizes: try: - self._capture_graph_for_size(size, device, dtype) + self._capture(size, device, dtype) logger.info(" Captured CUDA Graph for size=%d", size) except Exception: - logger.warning(" Failed to capture CUDA Graph for size=%d", size, exc_info=True) + logger.warning(" Failed to capture graph for size=%d", size, exc_info=True) self._warmed_up = True - logger.info("CUDA Graph warmup complete. Captured %d graphs.", len(self.graphs)) - - def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch.dtype): - static_input = torch.zeros( - 1, - self.num_quantizers, - size, - dtype=dtype, - device=device, - ) + logger.info("CUDA Graph warmup complete: %d/%d captured", len(self.graphs), len(self.capture_sizes)) + def _capture(self, size: int, device: torch.device, dtype: torch.dtype): + static_input = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): _ = self.decoder(static_input) - torch.cuda.synchronize(device) graph = CUDAGraph() @@ -122,10 +137,7 @@ def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch. self.static_outputs[size] = static_output def decode(self, codes: torch.Tensor) -> torch.Tensor: - if not self.enabled or not self._warmed_up: - return self.decoder(codes) - - if codes.shape[0] != 1: + if not self.enabled or not self._warmed_up or codes.shape[0] != 1: return self.decoder(codes) actual_size = codes.shape[-1] @@ -136,14 +148,10 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor: self.static_inputs[padded_size].zero_() self.static_inputs[padded_size][:, :, :actual_size] = codes - self.graphs[padded_size].replay() - output = self.static_outputs[padded_size] - total_upsample = self.decoder.total_upsample - actual_output_len = actual_size * total_upsample - - return output[..., :actual_output_len].clone() + actual_out_len = actual_size * self.decoder.total_upsample + return self.static_outputs[padded_size][..., :actual_out_len].clone() def chunked_decode_with_cudagraph( self, diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 79a2810bfad..6be039df105 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -107,11 +107,17 @@ def _ensure_speech_tokenizer_loaded(self) -> None: self._output_sample_rate = out_sr self._total_upsample = int(decoder.total_upsample) + # Precompute SnakeBeta exp caches (benefits both Triton and eager paths) + if hasattr(decoder, "precompute_snake_caches"): + decoder.precompute_snake_caches() + if hasattr(decoder, "enable_cudagraph"): device = self._module_device(decoder) if device.type == "cuda": try: - capture_sizes = None + chunk_frames = 0 + left_frames = 0 + model_cfg = getattr(self.vllm_config, "model_config", None) connector_cfg = getattr(model_cfg, "stage_connector_config", None) extra_cfg = ( @@ -122,12 +128,12 @@ def _ensure_speech_tokenizer_loaded(self) -> None: if isinstance(extra_cfg, dict): chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) left_frames = int(extra_cfg.get("codec_left_context_frames") or 0) - if chunk_frames > 0 and left_frames >= 0: - from .cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper - steady_window = left_frames + chunk_frames - capture_sizes = sorted({*CUDAGraphDecoderWrapper.DEFAULT_CAPTURE_SIZES, steady_window}) - decoder.enable_cudagraph(capture_sizes=capture_sizes, device=device) + decoder.enable_cudagraph( + device=device, + codec_chunk_frames=chunk_frames, + codec_left_context_frames=left_frames, + ) logger.info("Code2Wav decoder CUDA Graph enabled") except Exception: logger.warning("Failed to enable CUDA Graph for Code2Wav decoder", exc_info=True) @@ -265,18 +271,12 @@ def forward( pass # Decode directly via decoder.chunked_decode(), staying entirely on GPU. - # For single request: no padding needed, fast path. - # For multiple requests: decode each individually to avoid padding overhead. + # Each request decoded individually with CUDA graph replay at bs=1. wav_tensors: list[torch.Tensor] = [] - if len(valid_codes_qf) == 1: - codes_bqf = valid_codes_qf[0].unsqueeze(0) # [1, Q, F] + for codes_qf in valid_codes_qf: + codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F] wav = decoder.chunked_decode(codes_bqf) # [1, 1, wav_len] wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len] - else: - for codes_qf in valid_codes_qf: - codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F] - wav = decoder.chunked_decode(codes_bqf) - wav_tensors.append(wav.squeeze(0).squeeze(0)) audios: list[torch.Tensor] = [empty] * num_req srs = [sr_tensor] * num_req diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 5719d1024f2..dfa48df298a 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -614,32 +614,114 @@ class SnakeBeta(nn.Module): https://huggingface.co/papers/2006.08195 """ + _triton_kernel = None # None = untried, False = unavailable, callable = ready + _TRITON_MAX_BLOCK_T = 4096 # upper bound for time-axis tile size + + @staticmethod + def _init_triton(): + """Load and JIT-compile the fused Triton kernel (once).""" + if SnakeBeta._triton_kernel is not None: + return SnakeBeta._triton_kernel is not False + try: + import triton + import triton.language as tl + except ImportError: + SnakeBeta._triton_kernel = False + return False + + @triton.jit + def _kernel( # noqa: N803 + x_ptr, + exp_alpha_ptr, + inv_beta_ptr, + out_ptr, + stride_b, + stride_c, + t_len, + block_t: tl.constexpr, + ): + """Fused SnakeBeta using precomputed exp(α) and 1/(exp(β)+ε).""" + bid = tl.program_id(0) + cid = tl.program_id(1) + t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) + mask = t_off < t_len + + x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0) + ea = tl.load(exp_alpha_ptr + cid) + ib = tl.load(inv_beta_ptr + cid) + sin_val = tl.sin(x * ea) + result = x + ib * sin_val * sin_val + + tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask) + + SnakeBeta._triton_kernel = _kernel + return True + def __init__(self, in_features, alpha=1.0): super().__init__() self.in_features = in_features - # initialize alpha self.alpha = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha) self.no_div_by_zero = 0.000000001 - def forward(self, hidden_states): - """ - Forward pass of the function. - Applies the function to the input elementwise. - SnakeBeta ∶= x + 1/b * sin^2 (xa) - """ - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - beta = self.beta.unsqueeze(0).unsqueeze(-1) - alpha = torch.exp(alpha) - beta = torch.exp(beta) - hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( - torch.sin(hidden_states * alpha), 2 - ) + # Precomputed buffers (populated by precompute_exp_cache) + self.register_buffer("_exp_alpha", None, persistent=False) + self.register_buffer("_inv_beta", None, persistent=False) + + def precompute_exp_cache(self): + """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers.""" + with torch.no_grad(): + self._exp_alpha = torch.exp(self.alpha).contiguous() + self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous() + @property + def _cached(self): + return self._exp_alpha is not None + + def forward(self, hidden_states): + """SnakeBeta := x + 1/b * sin^2(x*a)""" + if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): + try: + return self._triton_forward(hidden_states) + except Exception: + logger.warning("Triton SnakeBeta failed, falling back to eager", exc_info=True) + SnakeBeta._triton_kernel = False + return self._eager_forward(hidden_states) + + def _eager_forward(self, hidden_states): + if self._cached: + exp_alpha = self._exp_alpha.unsqueeze(0).unsqueeze(-1) + inv_beta = self._inv_beta.unsqueeze(0).unsqueeze(-1) + else: + exp_alpha = torch.exp(self.alpha).unsqueeze(0).unsqueeze(-1) + inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).unsqueeze(0).unsqueeze(-1) + hidden_states = hidden_states + inv_beta * torch.pow(torch.sin(hidden_states * exp_alpha), 2) return hidden_states + def _triton_forward(self, x): + import triton + + if not self._cached: + self.precompute_exp_cache() + + x = x.contiguous() + B, C, T = x.shape + out = torch.empty_like(x) + block_t = min(triton.next_power_of_2(T), self._TRITON_MAX_BLOCK_T) + self._triton_kernel[(B, C, triton.cdiv(T, block_t))]( + x, + self._exp_alpha, + self._inv_beta, + out, + x.stride(0), + x.stride(1), + t_len=T, + block_t=block_t, + ) + return out + class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module): def __init__(self, dim: int = 16, dilation: int = 1): @@ -876,10 +958,22 @@ def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig): self._cudagraph_enabled = False self._cudagraph_wrapper = None + def precompute_snake_caches(self): + """Precompute exp(alpha) and 1/(exp(beta)+eps) for all SnakeBeta modules.""" + count = 0 + for module in self.modules(): + if isinstance(module, SnakeBeta): + module.precompute_exp_cache() + count += 1 + if count > 0: + logger.info("Precomputed exp caches for %d SnakeBeta activations", count) + def enable_cudagraph( self, capture_sizes: list[int] | None = None, device: torch.device | None = None, + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, ): from ..cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper @@ -888,16 +982,24 @@ def enable_cudagraph( if device.type != "cuda": logger.warning("Cannot enable CUDA Graph: decoder is not on a CUDA device (got %s)", device) return + self._cudagraph_wrapper = CUDAGraphDecoderWrapper( decoder=self, capture_sizes=capture_sizes, num_quantizers=self.config.num_quantizers, enabled=True, ) - self._cudagraph_wrapper.warmup(device, dtype=torch.long) + self._cudagraph_wrapper.warmup( + device, + dtype=torch.long, + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + ) self._cudagraph_enabled = True - sizes = self._cudagraph_wrapper.capture_sizes - logger.info("CUDA Graph enabled for decoder with sizes: %s", sizes) + logger.info( + "CUDA Graph enabled for decoder: seq_lens=%s", + self._cudagraph_wrapper.capture_sizes, + ) def disable_cudagraph(self): self._cudagraph_enabled = False