From 60ffb77128c9ea918ee21eff7cdc0a20ec31582c Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 10 Mar 2026 19:29:46 +0000 Subject: [PATCH 01/12] optimize cuda graph capture script for code2wav Signed-off-by: pablo --- .../qwen3_tts/cuda_graph_decoder_wrapper.py | 103 ++++++++++-------- 1 file changed, 59 insertions(+), 44 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index de278edf7a9..0b7ba65fe6d 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -29,7 +29,7 @@ class CUDAGraphDecoderWrapper: output = wrapper.decode(codes) # Automatically uses CUDA graph if possible """ - DEFAULT_CAPTURE_SIZES = [25, 50, 100, 150, 200, 250, 300] + DEFAULT_CAPTURE_SIZES = [2, 4, 8, 16, 25, 32, 50, 100, 150, 200, 250, 300] def __init__( self, @@ -39,7 +39,8 @@ def __init__( enabled: bool = True, ): self.decoder = decoder - self.capture_sizes = capture_sizes or self.DEFAULT_CAPTURE_SIZES + self._explicit_sizes = capture_sizes is not None + self.capture_sizes = sorted(capture_sizes or self.DEFAULT_CAPTURE_SIZES) self.num_quantizers = num_quantizers self.enabled = enabled @@ -50,66 +51,87 @@ def __init__( self._warmed_up = False self._device = None + @staticmethod + def compute_capture_sizes( + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, + decode_chunk_size: int = 300, + decode_left_context: int = 25, + ) -> list[int]: + """Compute capture sizes from chunking config for high graph hit rate.""" + sizes: set[int] = set() + + # Streaming sizes + if codec_chunk_frames > 0: + sizes.add(codec_chunk_frames) + if codec_left_context_frames > 0: + sizes.add(codec_chunk_frames + codec_left_context_frames) + + # Non-streaming: first chunk + steady-state + sizes.add(decode_chunk_size) + sizes.add(decode_chunk_size + decode_left_context) + + # Buckets for variable-length last chunks + step = max(decode_chunk_size // 8, 10) + for i in range(step, decode_chunk_size + 1, step): + sizes.add(i + decode_left_context) + + # Power-of-2 small sizes for dynamic initial chunk sizes + for s in [2, 4, 8, 16, 32, 64]: + sizes.add(s) + + return sorted(sizes) + def _get_padded_size(self, actual_size: int) -> int | None: for size in self.capture_sizes: if actual_size <= size: return size return None - def warmup(self, device: torch.device, dtype: torch.dtype = torch.long): - if device.type != "cuda": - logger.info("CUDA Graph warmup skipped: device %s is not CUDA", device) - return - - if not self.enabled: - logger.info("CUDA Graph is disabled, skipping warmup") - return - - if self._warmed_up: - logger.warning("CUDA Graph already warmed up, skipping") + def warmup( + self, + device: torch.device, + dtype: torch.dtype = torch.long, + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, + ): + if device.type != "cuda" or not self.enabled or self._warmed_up: return self._device = device self.decoder.eval() logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) + if not self._explicit_sizes: + self.capture_sizes = self.compute_capture_sizes( + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + ) + + logger.info("CUDA Graph warmup: %d sizes %s", len(self.capture_sizes), self.capture_sizes) # Warmup runs to ensure CUDA memory is allocated for size in self.capture_sizes: - dummy_codes = torch.zeros( - 1, - self.num_quantizers, - size, - dtype=dtype, - device=device, - ) + dummy = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): - _ = self.decoder(dummy_codes) + _ = self.decoder(dummy) torch.cuda.synchronize(device) for size in self.capture_sizes: try: - self._capture_graph_for_size(size, device, dtype) + self._capture(size, device, dtype) logger.info(" Captured CUDA Graph for size=%d", size) except Exception: - logger.warning(" Failed to capture CUDA Graph for size=%d", size, exc_info=True) + logger.warning(" Failed to capture graph for size=%d", size, exc_info=True) self._warmed_up = True - logger.info("CUDA Graph warmup complete. Captured %d graphs.", len(self.graphs)) - - def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch.dtype): - static_input = torch.zeros( - 1, - self.num_quantizers, - size, - dtype=dtype, - device=device, - ) + logger.info("CUDA Graph warmup complete: %d/%d captured", len(self.graphs), len(self.capture_sizes)) + def _capture(self, size: int, device: torch.device, dtype: torch.dtype): + static_input = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): _ = self.decoder(static_input) - torch.cuda.synchronize(device) graph = CUDAGraph() @@ -122,10 +144,7 @@ def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch. self.static_outputs[size] = static_output def decode(self, codes: torch.Tensor) -> torch.Tensor: - if not self.enabled or not self._warmed_up: - return self.decoder(codes) - - if codes.shape[0] != 1: + if not self.enabled or not self._warmed_up or codes.shape[0] != 1: return self.decoder(codes) actual_size = codes.shape[-1] @@ -136,14 +155,10 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor: self.static_inputs[padded_size].zero_() self.static_inputs[padded_size][:, :, :actual_size] = codes - self.graphs[padded_size].replay() - output = self.static_outputs[padded_size] - total_upsample = self.decoder.total_upsample - actual_output_len = actual_size * total_upsample - - return output[..., :actual_output_len].clone() + actual_out_len = actual_size * self.decoder.total_upsample + return self.static_outputs[padded_size][..., :actual_out_len].clone() def chunked_decode_with_cudagraph( self, From 4b6a9491e8ae730ee45d32ee5b9cb18b148567a1 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 10 Mar 2026 19:31:31 +0000 Subject: [PATCH 02/12] add triton kernel for SnakeBeta Code2Wav Signed-off-by: pablo --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 26 +++--- .../modeling_qwen3_tts_tokenizer_v2.py | 82 ++++++++++++++++++- 2 files changed, 89 insertions(+), 19 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index eee751e1683..a08525ff92f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -111,7 +111,9 @@ def _ensure_speech_tokenizer_loaded(self) -> None: device = self._module_device(decoder) if device.type == "cuda": try: - capture_sizes = None + chunk_frames = 0 + left_frames = 0 + model_cfg = getattr(self.vllm_config, "model_config", None) connector_cfg = getattr(model_cfg, "stage_connector_config", None) extra_cfg = ( @@ -122,12 +124,12 @@ def _ensure_speech_tokenizer_loaded(self) -> None: if isinstance(extra_cfg, dict): chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) left_frames = int(extra_cfg.get("codec_left_context_frames") or 0) - if chunk_frames > 0 and left_frames >= 0: - from .cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper - steady_window = left_frames + chunk_frames - capture_sizes = sorted({*CUDAGraphDecoderWrapper.DEFAULT_CAPTURE_SIZES, steady_window}) - decoder.enable_cudagraph(capture_sizes=capture_sizes, device=device) + decoder.enable_cudagraph( + device=device, + codec_chunk_frames=chunk_frames, + codec_left_context_frames=left_frames, + ) logger.info("Code2Wav decoder CUDA Graph enabled") except Exception: logger.warning("Failed to enable CUDA Graph for Code2Wav decoder", exc_info=True) @@ -265,18 +267,12 @@ def forward( pass # Decode directly via decoder.chunked_decode(), staying entirely on GPU. - # For single request: no padding needed, fast path. - # For multiple requests: decode each individually to avoid padding overhead. + # Each request decoded individually with CUDA graph replay at bs=1. wav_tensors: list[torch.Tensor] = [] - if len(valid_codes_qf) == 1: - codes_bqf = valid_codes_qf[0].unsqueeze(0) # [1, Q, F] + for codes_qf in valid_codes_qf: + codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F] wav = decoder.chunked_decode(codes_bqf) # [1, 1, wav_len] wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len] - else: - for codes_qf in valid_codes_qf: - codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F] - wav = decoder.chunked_decode(codes_bqf) - wav_tensors.append(wav.squeeze(0).squeeze(0)) audios: list[torch.Tensor] = [empty] * num_req srs = [sr_tensor] * num_req diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 5719d1024f2..4566d07517f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -614,6 +614,48 @@ class SnakeBeta(nn.Module): https://huggingface.co/papers/2006.08195 """ + _triton_kernel = None + + @staticmethod + def _init_triton(): + """Load and JIT-compile the fused Triton kernel (once).""" + if SnakeBeta._triton_kernel is not None: + return True + try: + import triton + import triton.language as tl + except ImportError: + return False + + @triton.jit + def _kernel( # noqa: N803 + x_ptr, + alpha_ptr, + beta_ptr, + out_ptr, + stride_b, + stride_c, + t_len: tl.constexpr, + eps: tl.constexpr, + block_t: tl.constexpr, + ): + """Fused SnakeBeta: x + (1/exp(β)) · sin²(x · exp(α)).""" + bid = tl.program_id(0) + cid = tl.program_id(1) + t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) + mask = t_off < t_len + + x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask) + alpha = tl.exp(tl.load(alpha_ptr + cid)) + beta = tl.exp(tl.load(beta_ptr + cid)) + sin_val = tl.sin(x * alpha) + result = x + (1.0 / (beta + eps)) * sin_val * sin_val + + tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask) + + SnakeBeta._triton_kernel = _kernel + return True + def __init__(self, in_features, alpha=1.0): super().__init__() self.in_features = in_features @@ -630,6 +672,11 @@ def forward(self, hidden_states): Applies the function to the input elementwise. SnakeBeta ∶= x + 1/b * sin^2 (xa) """ + if hidden_states.is_cuda and self._init_triton(): + return self._triton_forward(hidden_states) + return self._eager_forward(hidden_states) + + def _eager_forward(self, hidden_states): alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] beta = self.beta.unsqueeze(0).unsqueeze(-1) alpha = torch.exp(alpha) @@ -637,9 +684,27 @@ def forward(self, hidden_states): hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( torch.sin(hidden_states * alpha), 2 ) - return hidden_states + def _triton_forward(self, x): + import triton + + B, C, T = x.shape + out = torch.empty_like(x) + block_t = min(triton.next_power_of_2(T), 1024) + self._triton_kernel[(B, C, triton.cdiv(T, block_t))]( + x, + self.alpha, + self.beta, + out, + x.stride(0), + x.stride(1), + t_len=T, + eps=self.no_div_by_zero, + block_t=block_t, + ) + return out + class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module): def __init__(self, dim: int = 16, dilation: int = 1): @@ -880,6 +945,8 @@ def enable_cudagraph( self, capture_sizes: list[int] | None = None, device: torch.device | None = None, + codec_chunk_frames: int = 0, + codec_left_context_frames: int = 0, ): from ..cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper @@ -894,10 +961,17 @@ def enable_cudagraph( num_quantizers=self.config.num_quantizers, enabled=True, ) - self._cudagraph_wrapper.warmup(device, dtype=torch.long) + self._cudagraph_wrapper.warmup( + device, + dtype=torch.long, + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + ) self._cudagraph_enabled = True - sizes = self._cudagraph_wrapper.capture_sizes - logger.info("CUDA Graph enabled for decoder with sizes: %s", sizes) + logger.info( + "CUDA Graph enabled for decoder: seq_lens=%s", + self._cudagraph_wrapper.capture_sizes, + ) def disable_cudagraph(self): self._cudagraph_enabled = False From 1b9094e2f5498bdd28c4f6b95c58b824bc251daa Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 10 Mar 2026 19:33:17 +0000 Subject: [PATCH 03/12] add tests for cuda graph and triton snakebeta Signed-off-by: pablo --- .../qwen3_tts/test_cuda_graph_decoder.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 3739687606d..57de426ce6e 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -270,3 +270,53 @@ def test_deterministic_across_calls(decoder, wrapper): out1 = wrapper.decode(codes) out2 = wrapper.decode(codes) torch.testing.assert_close(out1, out2, atol=0, rtol=0) + + +# ────────────────────────────────────────────────────────────────── +# 6. compute_capture_sizes +# ────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "kwargs,expected_in", + [ + ({}, [300, 325, 2, 32]), + ({"codec_chunk_frames": 33, "codec_left_context_frames": 25}, [33, 58, 300, 325]), + ({"decode_chunk_size": 150, "decode_left_context": 10}, [150, 160]), + ], + ids=["default", "streaming", "custom_decode"], +) +def test_compute_capture_sizes(kwargs, expected_in): + """compute_capture_sizes produces expected sizes for various configs.""" + sizes = CUDAGraphDecoderWrapper.compute_capture_sizes(**kwargs) + for val in expected_in: + assert val in sizes, f"{val} not in {sizes}" + + +# ────────────────────────────────────────────────────────────────── +# 7. SnakeBeta Triton kernel vs eager equivalence +# ────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "batch,channels,seq_len", + [(2, 64, 1000), (1, 32, 1), (1, 32, 7), (1, 32, 128), (1, 32, 1024), (1, 32, 4096)], +) +def test_snakebeta_triton_vs_eager(batch, channels, seq_len): + """Fused Triton SnakeBeta kernel must match eager PyTorch output.""" + from vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( + SnakeBeta, + ) + + if not SnakeBeta._init_triton(): + pytest.skip("Triton not available") + + torch.manual_seed(42) + snake = SnakeBeta(in_features=channels).to(DEVICE).eval() + x = torch.randn(batch, channels, seq_len, device=DEVICE) + + with torch.no_grad(): + eager_out = snake._eager_forward(x) + triton_out = snake._triton_forward(x) + + torch.testing.assert_close(triton_out, eager_out, atol=1e-5, rtol=1e-5) From 1eefebf36c06dbc3a5e7ab7228d7120fb1d4132d Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 06:46:41 +0000 Subject: [PATCH 04/12] add reviewers comments Signed-off-by: JuanPZuluaga --- .../qwen3_tts/cuda_graph_decoder_wrapper.py | 3 +-- .../modeling_qwen3_tts_tokenizer_v2.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index 0b7ba65fe6d..4e0e20ad1b1 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -101,14 +101,13 @@ def warmup( self._device = device self.decoder.eval() - logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) if not self._explicit_sizes: self.capture_sizes = self.compute_capture_sizes( codec_chunk_frames=codec_chunk_frames, codec_left_context_frames=codec_left_context_frames, ) - logger.info("CUDA Graph warmup: %d sizes %s", len(self.capture_sizes), self.capture_sizes) + logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) # Warmup runs to ensure CUDA memory is allocated for size in self.capture_sizes: diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 4566d07517f..41fe74c1c6c 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -620,11 +620,12 @@ class SnakeBeta(nn.Module): def _init_triton(): """Load and JIT-compile the fused Triton kernel (once).""" if SnakeBeta._triton_kernel is not None: - return True + return SnakeBeta._triton_kernel is not False try: import triton import triton.language as tl except ImportError: + SnakeBeta._triton_kernel = False return False @triton.jit @@ -640,14 +641,18 @@ def _kernel( # noqa: N803 block_t: tl.constexpr, ): """Fused SnakeBeta: x + (1/exp(β)) · sin²(x · exp(α)).""" + # Grid: (batch, channel, time_block) bid = tl.program_id(0) cid = tl.program_id(1) t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) - mask = t_off < t_len + mask = t_off < t_len # guard out-of-bounds time steps + # Load input tile for this (batch, channel) slice x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask) + # Per-channel learned parameters (log-space → exp) alpha = tl.exp(tl.load(alpha_ptr + cid)) beta = tl.exp(tl.load(beta_ptr + cid)) + # SnakeBeta activation: x + (1/β) · sin²(x·α) sin_val = tl.sin(x * alpha) result = x + (1.0 / (beta + eps)) * sin_val * sin_val @@ -672,8 +677,11 @@ def forward(self, hidden_states): Applies the function to the input elementwise. SnakeBeta ∶= x + 1/b * sin^2 (xa) """ - if hidden_states.is_cuda and self._init_triton(): - return self._triton_forward(hidden_states) + if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): + try: + return self._triton_forward(hidden_states) + except Exception: + pass return self._eager_forward(hidden_states) def _eager_forward(self, hidden_states): @@ -689,6 +697,7 @@ def _eager_forward(self, hidden_states): def _triton_forward(self, x): import triton + x = x.contiguous() B, C, T = x.shape out = torch.empty_like(x) block_t = min(triton.next_power_of_2(T), 1024) From 4ef54d1a6378bf699459bbcbe9109b7692cc36ec Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 06:55:31 +0000 Subject: [PATCH 05/12] t_len constexpr Signed-off-by: JuanPZuluaga --- .../qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 41fe74c1c6c..1139372ce41 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -636,7 +636,7 @@ def _kernel( # noqa: N803 out_ptr, stride_b, stride_c, - t_len: tl.constexpr, + t_len, eps: tl.constexpr, block_t: tl.constexpr, ): From adc00c4955eee4d5f0d9d52612464421878fb67d Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 07:45:22 +0000 Subject: [PATCH 06/12] solve except:pass Signed-off-by: JuanPZuluaga --- .../tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 1139372ce41..b37da676ba4 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -614,7 +614,7 @@ class SnakeBeta(nn.Module): https://huggingface.co/papers/2006.08195 """ - _triton_kernel = None + _triton_kernel = None # None = untried, False = unavailable, callable = ready @staticmethod def _init_triton(): @@ -648,7 +648,7 @@ def _kernel( # noqa: N803 mask = t_off < t_len # guard out-of-bounds time steps # Load input tile for this (batch, channel) slice - x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask) + x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0) # Per-channel learned parameters (log-space → exp) alpha = tl.exp(tl.load(alpha_ptr + cid)) beta = tl.exp(tl.load(beta_ptr + cid)) @@ -681,7 +681,8 @@ def forward(self, hidden_states): try: return self._triton_forward(hidden_states) except Exception: - pass + logger.warning("Triton SnakeBeta failed, falling back to eager", exc_info=True) + SnakeBeta._triton_kernel = False return self._eager_forward(hidden_states) def _eager_forward(self, hidden_states): From 596af6b132f7002ce13f87fee3f4bf5b695e5668 Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 08:12:39 +0000 Subject: [PATCH 07/12] simplify logic in sizes captured by cuda graph Signed-off-by: JuanPZuluaga --- .../qwen3_tts/test_cuda_graph_decoder.py | 8 +++---- .../qwen3_tts/cuda_graph_decoder_wrapper.py | 24 ++++--------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 57de426ce6e..0c7d96375cf 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -280,11 +280,11 @@ def test_deterministic_across_calls(decoder, wrapper): @pytest.mark.parametrize( "kwargs,expected_in", [ - ({}, [300, 325, 2, 32]), - ({"codec_chunk_frames": 33, "codec_left_context_frames": 25}, [33, 58, 300, 325]), - ({"decode_chunk_size": 150, "decode_left_context": 10}, [150, 160]), + ({}, [2, 4, 8, 16, 32, 64, 128, 156, 212, 256, 384, 512]), + ({"codec_chunk_frames": 33, "codec_left_context_frames": 25}, [33, 58, 64, 256, 512]), + ({"codec_chunk_frames": 17, "codec_left_context_frames": 25}, [17, 42, 64, 256]), ], - ids=["default", "streaming", "custom_decode"], + ids=["default", "streaming_c33", "streaming_c17"], ) def test_compute_capture_sizes(kwargs, expected_in): """compute_capture_sizes produces expected sizes for various configs.""" diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index 4e0e20ad1b1..2971ee047d6 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -29,7 +29,8 @@ class CUDAGraphDecoderWrapper: output = wrapper.decode(codes) # Automatically uses CUDA graph if possible """ - DEFAULT_CAPTURE_SIZES = [2, 4, 8, 16, 25, 32, 50, 100, 150, 200, 250, 300] + # Power-of-2 backbone for initial phase, low TTFC + mid-range fills to limit padding waste + _BASE_CAPTURE_SIZES = [2, 4, 8, 16, 32, 64, 128, 156, 212, 256, 384, 512] def __init__( self, @@ -40,7 +41,7 @@ def __init__( ): self.decoder = decoder self._explicit_sizes = capture_sizes is not None - self.capture_sizes = sorted(capture_sizes or self.DEFAULT_CAPTURE_SIZES) + self.capture_sizes = sorted(capture_sizes) if capture_sizes else [] self.num_quantizers = num_quantizers self.enabled = enabled @@ -55,31 +56,16 @@ def __init__( def compute_capture_sizes( codec_chunk_frames: int = 0, codec_left_context_frames: int = 0, - decode_chunk_size: int = 300, - decode_left_context: int = 25, ) -> list[int]: """Compute capture sizes from chunking config for high graph hit rate.""" - sizes: set[int] = set() + sizes: set[int] = set(CUDAGraphDecoderWrapper._BASE_CAPTURE_SIZES) - # Streaming sizes + # Exact hits for streaming steady-state (avoids padding overhead) if codec_chunk_frames > 0: sizes.add(codec_chunk_frames) if codec_left_context_frames > 0: sizes.add(codec_chunk_frames + codec_left_context_frames) - # Non-streaming: first chunk + steady-state - sizes.add(decode_chunk_size) - sizes.add(decode_chunk_size + decode_left_context) - - # Buckets for variable-length last chunks - step = max(decode_chunk_size // 8, 10) - for i in range(step, decode_chunk_size + 1, step): - sizes.add(i + decode_left_context) - - # Power-of-2 small sizes for dynamic initial chunk sizes - for s in [2, 4, 8, 16, 32, 64]: - sizes.add(s) - return sorted(sizes) def _get_padded_size(self, actual_size: int) -> int | None: From 621df733573f68d5eeaf8ea6d9cb845cfb04d669 Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 22:02:00 +0000 Subject: [PATCH 08/12] capture more graph Signed-off-by: JuanPZuluaga --- .../qwen3_tts/test_cuda_graph_decoder.py | 24 +++++++++++++------ .../qwen3_tts/cuda_graph_decoder_wrapper.py | 18 ++++++++++---- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 0c7d96375cf..86af757809d 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -278,19 +278,29 @@ def test_deterministic_across_calls(decoder, wrapper): @pytest.mark.parametrize( - "kwargs,expected_in", + "kwargs,expected_in,not_expected", [ - ({}, [2, 4, 8, 16, 32, 64, 128, 156, 212, 256, 384, 512]), - ({"codec_chunk_frames": 33, "codec_left_context_frames": 25}, [33, 58, 64, 256, 512]), - ({"codec_chunk_frames": 17, "codec_left_context_frames": 25}, [17, 42, 64, 256]), + ({}, [2, 4, 8, 16, 32, 64, 128, 256, 325], [512]), + ( + {"codec_chunk_frames": 33, "codec_left_context_frames": 25}, + [2, 4, 8, 16, 32, 33, 58, 64, 128, 256, 325], + [512], + ), + ( + {"codec_chunk_frames": 25, "codec_left_context_frames": 25}, + [2, 4, 8, 16, 25, 32, 50, 64, 128, 256, 325], + [512], + ), ], - ids=["default", "streaming_c33", "streaming_c17"], + ids=["default", "streaming_c33", "streaming_c25"], ) -def test_compute_capture_sizes(kwargs, expected_in): - """compute_capture_sizes produces expected sizes for various configs.""" +def test_compute_capture_sizes(kwargs, expected_in, not_expected): + """compute_capture_sizes produces expected sizes capped by max useful size.""" sizes = CUDAGraphDecoderWrapper.compute_capture_sizes(**kwargs) for val in expected_in: assert val in sizes, f"{val} not in {sizes}" + for val in not_expected: + assert val not in sizes, f"{val} should not be in {sizes}" # ────────────────────────────────────────────────────────────────── diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index 2971ee047d6..96f8c799c13 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -29,9 +29,6 @@ class CUDAGraphDecoderWrapper: output = wrapper.decode(codes) # Automatically uses CUDA graph if possible """ - # Power-of-2 backbone for initial phase, low TTFC + mid-range fills to limit padding waste - _BASE_CAPTURE_SIZES = [2, 4, 8, 16, 32, 64, 128, 156, 212, 256, 384, 512] - def __init__( self, decoder: torch.nn.Module, @@ -56,16 +53,27 @@ def __init__( def compute_capture_sizes( codec_chunk_frames: int = 0, codec_left_context_frames: int = 0, + decode_chunk_size: int = 300, + decode_left_context: int = 25, ) -> list[int]: """Compute capture sizes from chunking config for high graph hit rate.""" - sizes: set[int] = set(CUDAGraphDecoderWrapper._BASE_CAPTURE_SIZES) + sizes: set[int] = set() - # Exact hits for streaming steady-state (avoids padding overhead) + # Streaming exact hits if codec_chunk_frames > 0: sizes.add(codec_chunk_frames) if codec_left_context_frames > 0: sizes.add(codec_chunk_frames + codec_left_context_frames) + # Non-streaming chunked decode: full chunk + last-chunk buckets + non_stream_max = decode_chunk_size + decode_left_context + sizes.add(non_stream_max) + + # Power-of-2 buckets covering both streaming IC sizes and non-streaming last-chunk sizes + for p2 in [2, 4, 8, 16, 32, 64, 128, 256]: + if p2 <= non_stream_max: + sizes.add(p2) + return sorted(sizes) def _get_padded_size(self, actual_size: int) -> int | None: From b7d51e6f51331bc0261cbcf29ac470386639e062 Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Wed, 11 Mar 2026 22:34:42 +0000 Subject: [PATCH 09/12] fix pre commit in async_omni Signed-off-by: JuanPZuluaga --- vllm_omni/entrypoints/async_omni.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 0761ae6ef34..3c3cb4cd65e 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -140,13 +140,6 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: getattr(self, "_inline_engine", None), ) - async def get_supported_tasks(self) -> set[str]: - """Return supported tasks based on the configured stage output modalities.""" - tasks: set[str] = set() - if "text" in self.output_modalities: - tasks.add("generate") - return tasks - def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]: """Create default diffusion stage configuration.""" # TODO: here is different from the Omni class. We should merge the two in the future. From d165fdeb4486a42f3417f80b7e177feb13dce2fb Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Thu, 19 Mar 2026 05:43:35 +0000 Subject: [PATCH 10/12] cache exp call and increase block size cap Signed-off-by: JuanPZuluaga --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 4 + .../modeling_qwen3_tts_tokenizer_v2.py | 83 ++++++++++++------- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 40dc511127f..6be039df105 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -107,6 +107,10 @@ def _ensure_speech_tokenizer_loaded(self) -> None: self._output_sample_rate = out_sr self._total_upsample = int(decoder.total_upsample) + # Precompute SnakeBeta exp caches (benefits both Triton and eager paths) + if hasattr(decoder, "precompute_snake_caches"): + decoder.precompute_snake_caches() + if hasattr(decoder, "enable_cudagraph"): device = self._module_device(decoder) if device.type == "cuda": diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index b37da676ba4..beb4fd6fbbf 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -615,6 +615,7 @@ class SnakeBeta(nn.Module): """ _triton_kernel = None # None = untried, False = unavailable, callable = ready + _TRITON_MAX_BLOCK_T = 4096 # upper bound for time-axis tile size @staticmethod def _init_triton(): @@ -631,30 +632,25 @@ def _init_triton(): @triton.jit def _kernel( # noqa: N803 x_ptr, - alpha_ptr, - beta_ptr, + exp_alpha_ptr, + inv_beta_ptr, out_ptr, stride_b, stride_c, t_len, - eps: tl.constexpr, block_t: tl.constexpr, ): - """Fused SnakeBeta: x + (1/exp(β)) · sin²(x · exp(α)).""" - # Grid: (batch, channel, time_block) + """Fused SnakeBeta using precomputed exp(α) and 1/(exp(β)+ε).""" bid = tl.program_id(0) cid = tl.program_id(1) t_off = tl.program_id(2) * block_t + tl.arange(0, block_t) - mask = t_off < t_len # guard out-of-bounds time steps + mask = t_off < t_len - # Load input tile for this (batch, channel) slice x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0) - # Per-channel learned parameters (log-space → exp) - alpha = tl.exp(tl.load(alpha_ptr + cid)) - beta = tl.exp(tl.load(beta_ptr + cid)) - # SnakeBeta activation: x + (1/β) · sin²(x·α) - sin_val = tl.sin(x * alpha) - result = x + (1.0 / (beta + eps)) * sin_val * sin_val + ea = tl.load(exp_alpha_ptr + cid) + ib = tl.load(inv_beta_ptr + cid) + sin_val = tl.sin(x * ea) + result = x + ib * sin_val * sin_val tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask) @@ -665,18 +661,31 @@ def __init__(self, in_features, alpha=1.0): super().__init__() self.in_features = in_features - # initialize alpha self.alpha = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha) self.no_div_by_zero = 0.000000001 - def forward(self, hidden_states): - """ - Forward pass of the function. - Applies the function to the input elementwise. - SnakeBeta ∶= x + 1/b * sin^2 (xa) + # Precomputed buffers (populated by precompute_exp_cache) + self.register_buffer("_exp_alpha", None, persistent=False) + self.register_buffer("_inv_beta", None, persistent=False) + + def precompute_exp_cache(self): + """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers. + + Call once after weight loading to eliminate transcendental ops at + inference time. Saves 2 exp() calls per element across every forward. """ + with torch.no_grad(): + self._exp_alpha = torch.exp(self.alpha).contiguous() + self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous() + + @property + def _cached(self): + return self._exp_alpha is not None + + def forward(self, hidden_states): + """SnakeBeta := x + 1/b * sin^2(x*a)""" if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): try: return self._triton_forward(hidden_states) @@ -686,31 +695,33 @@ def forward(self, hidden_states): return self._eager_forward(hidden_states) def _eager_forward(self, hidden_states): - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - beta = self.beta.unsqueeze(0).unsqueeze(-1) - alpha = torch.exp(alpha) - beta = torch.exp(beta) - hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( - torch.sin(hidden_states * alpha), 2 - ) + if self._cached: + exp_alpha = self._exp_alpha.unsqueeze(0).unsqueeze(-1) + inv_beta = self._inv_beta.unsqueeze(0).unsqueeze(-1) + else: + exp_alpha = torch.exp(self.alpha).unsqueeze(0).unsqueeze(-1) + inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).unsqueeze(0).unsqueeze(-1) + hidden_states = hidden_states + inv_beta * torch.pow(torch.sin(hidden_states * exp_alpha), 2) return hidden_states def _triton_forward(self, x): import triton + if not self._cached: + self.precompute_exp_cache() + x = x.contiguous() B, C, T = x.shape out = torch.empty_like(x) - block_t = min(triton.next_power_of_2(T), 1024) + block_t = min(triton.next_power_of_2(T), self._TRITON_MAX_BLOCK_T) self._triton_kernel[(B, C, triton.cdiv(T, block_t))]( x, - self.alpha, - self.beta, + self._exp_alpha, + self._inv_beta, out, x.stride(0), x.stride(1), t_len=T, - eps=self.no_div_by_zero, block_t=block_t, ) return out @@ -951,6 +962,16 @@ def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig): self._cudagraph_enabled = False self._cudagraph_wrapper = None + def precompute_snake_caches(self): + """Precompute exp(alpha) and 1/(exp(beta)+eps) for all SnakeBeta modules.""" + count = 0 + for module in self.modules(): + if isinstance(module, SnakeBeta): + module.precompute_exp_cache() + count += 1 + if count > 0: + logger.info("Precomputed exp caches for %d SnakeBeta activations", count) + def enable_cudagraph( self, capture_sizes: list[int] | None = None, @@ -965,6 +986,8 @@ def enable_cudagraph( if device.type != "cuda": logger.warning("Cannot enable CUDA Graph: decoder is not on a CUDA device (got %s)", device) return + + self.precompute_snake_caches() self._cudagraph_wrapper = CUDAGraphDecoderWrapper( decoder=self, capture_sizes=capture_sizes, From ad2af31dd67102917be21b8a130ae35a1572bb35 Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Thu, 19 Mar 2026 06:55:05 +0000 Subject: [PATCH 11/12] remove 2 times called precompute cache Signed-off-by: JuanPZuluaga --- .../qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index beb4fd6fbbf..5e024f9c4d9 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -987,7 +987,6 @@ def enable_cudagraph( logger.warning("Cannot enable CUDA Graph: decoder is not on a CUDA device (got %s)", device) return - self.precompute_snake_caches() self._cudagraph_wrapper = CUDAGraphDecoderWrapper( decoder=self, capture_sizes=capture_sizes, From 5c568261b9b1ad45459e098e25ccfcfb78af345a Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Thu, 19 Mar 2026 06:57:50 +0000 Subject: [PATCH 12/12] update docstring Signed-off-by: JuanPZuluaga --- .../tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 5e024f9c4d9..dfa48df298a 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -671,11 +671,7 @@ def __init__(self, in_features, alpha=1.0): self.register_buffer("_inv_beta", None, persistent=False) def precompute_exp_cache(self): - """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers. - - Call once after weight loading to eliminate transcendental ops at - inference time. Saves 2 exp() calls per element across every forward. - """ + """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers.""" with torch.no_grad(): self._exp_alpha = torch.exp(self.alpha).contiguous() self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous()