Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
60ffb77
optimize cuda graph capture script for code2wav
Mar 10, 2026
4b6a949
add triton kernel for SnakeBeta Code2Wav
Mar 10, 2026
1b9094e
add tests for cuda graph and triton snakebeta
Mar 10, 2026
2fc84a4
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 10, 2026
1eefebf
add reviewers comments
Mar 11, 2026
c2329a9
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 11, 2026
4ef54d1
t_len constexpr
Mar 11, 2026
adc00c4
solve except:pass
Mar 11, 2026
596af6b
simplify logic in sizes captured by cuda graph
Mar 11, 2026
d7dd80f
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 11, 2026
621df73
capture more graph
Mar 11, 2026
fc0af3a
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 11, 2026
b7d51e6
fix pre commit in async_omni
Mar 11, 2026
856dd74
merge main
Mar 12, 2026
fa5899d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 12, 2026
7d719f1
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 12, 2026
c45ea0e
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 13, 2026
6d8de1d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 13, 2026
60cc2c5
Merge branch 'main' into feat/code2wav-batch-cuda-graph
JuanPZuluaga Mar 15, 2026
6913060
Merge branch 'main' into feat/code2wav-batch-cuda-graph
JuanPZuluaga Mar 16, 2026
b4d415e
Merge branch 'main' into feat/code2wav-batch-cuda-graph
JuanPZuluaga Mar 16, 2026
a4c06aa
Merge branch 'main' into feat/code2wav-batch-cuda-graph
hsliuustc0106 Mar 18, 2026
37042bf
Merge branch 'main' into feat/code2wav-batch-cuda-graph
JuanPZuluaga Mar 18, 2026
0ab4594
Merge branch 'main' into feat/code2wav-batch-cuda-graph
JuanPZuluaga Mar 18, 2026
0bce467
Merge branch 'main' into feat/code2wav-batch-cuda-graph
linyueqian Mar 18, 2026
d7556ea
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 19, 2026
6156b64
Merge branch 'feat/code2wav-batch-cuda-graph' of https://github.com/J…
Mar 19, 2026
d165fde
cache exp call and increase block size cap
Mar 19, 2026
ad2af31
remove 2 times called precompute cache
Mar 19, 2026
5c56826
update docstring
Mar 19, 2026
cb95545
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 19, 2026
b586968
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 19, 2026
cd18316
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 19, 2026
63b665d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,63 @@ def test_deterministic_across_calls(decoder, wrapper):
out1 = wrapper.decode(codes)
out2 = wrapper.decode(codes)
torch.testing.assert_close(out1, out2, atol=0, rtol=0)


# ──────────────────────────────────────────────────────────────────
# 6. compute_capture_sizes
# ──────────────────────────────────────────────────────────────────


@pytest.mark.parametrize(
    "kwargs,expected_in,not_expected",
    [
        pytest.param(
            {},
            [2, 4, 8, 16, 32, 64, 128, 256, 325],
            [512],
            id="default",
        ),
        pytest.param(
            {"codec_chunk_frames": 33, "codec_left_context_frames": 25},
            [2, 4, 8, 16, 32, 33, 58, 64, 128, 256, 325],
            [512],
            id="streaming_c33",
        ),
        pytest.param(
            {"codec_chunk_frames": 25, "codec_left_context_frames": 25},
            [2, 4, 8, 16, 25, 32, 50, 64, 128, 256, 325],
            [512],
            id="streaming_c25",
        ),
    ],
)
def test_compute_capture_sizes(kwargs, expected_in, not_expected):
    """Check which sizes ``compute_capture_sizes`` does and does not produce.

    For each parametrized keyword set, every value in ``expected_in`` must be
    present in the returned sizes and every value in ``not_expected`` must be
    absent.
    """
    sizes = CUDAGraphDecoderWrapper.compute_capture_sizes(**kwargs)

    # Sizes that must have been selected for capture.
    for val in expected_in:
        assert val in sizes, f"{val} not in {sizes}"

    # Sizes that must have been left out.
    for val in not_expected:
        assert val not in sizes, f"{val} should not be in {sizes}"


# ──────────────────────────────────────────────────────────────────
# 7. SnakeBeta Triton kernel vs eager equivalence
# ──────────────────────────────────────────────────────────────────


@pytest.mark.parametrize(
    "batch,channels,seq_len",
    [
        (2, 64, 1000),
        (1, 32, 1),
        (1, 32, 7),
        (1, 32, 128),
        (1, 32, 1024),
        (1, 32, 4096),
    ],
)
def test_snakebeta_triton_vs_eager(batch, channels, seq_len):
    """Compare SnakeBeta's Triton forward against its eager forward.

    The test is skipped when ``SnakeBeta._init_triton()`` returns a falsy
    value. Otherwise both paths run on the same random input and their
    outputs must agree within ``atol=1e-5`` / ``rtol=1e-5``.
    """
    from vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import (
        SnakeBeta,
    )

    if not SnakeBeta._init_triton():
        pytest.skip("Triton not available")

    # Seed before building the module and the input so runs are repeatable.
    torch.manual_seed(42)
    activation = SnakeBeta(in_features=channels).to(DEVICE).eval()
    inputs = torch.randn(batch, channels, seq_len, device=DEVICE)

    with torch.no_grad():
        expected = activation._eager_forward(inputs)
        actual = activation._triton_forward(inputs)

    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class CUDAGraphDecoderWrapper:
output = wrapper.decode(codes) # Automatically uses CUDA graph if possible
"""

DEFAULT_CAPTURE_SIZES = [2, 4, 8, 16, 25, 32, 50, 100, 150, 200, 250, 300]

def __init__(
self,
decoder: torch.nn.Module,
Expand All @@ -39,7 +37,8 @@ def __init__(
enabled: bool = True,
):
self.decoder = decoder
self.capture_sizes = capture_sizes or self.DEFAULT_CAPTURE_SIZES
self._explicit_sizes = capture_sizes is not None
self.capture_sizes = sorted(capture_sizes) if capture_sizes else []
self.num_quantizers = num_quantizers
self.enabled = enabled

Expand All @@ -50,66 +49,82 @@ def __init__(
self._warmed_up = False
self._device = None

@staticmethod
def compute_capture_sizes(
codec_chunk_frames: int = 0,
codec_left_context_frames: int = 0,
decode_chunk_size: int = 300,
decode_left_context: int = 25,
) -> list[int]:
"""Compute capture sizes from chunking config for high graph hit rate."""
sizes: set[int] = set()

# Streaming exact hits
if codec_chunk_frames > 0:
sizes.add(codec_chunk_frames)
if codec_left_context_frames > 0:
sizes.add(codec_chunk_frames + codec_left_context_frames)

# Non-streaming chunked decode: full chunk + last-chunk buckets
non_stream_max = decode_chunk_size + decode_left_context
sizes.add(non_stream_max)

# Power-of-2 buckets covering both streaming IC sizes and non-streaming last-chunk sizes
for p2 in [2, 4, 8, 16, 32, 64, 128, 256]:
if p2 <= non_stream_max:
sizes.add(p2)

return sorted(sizes)

def _get_padded_size(self, actual_size: int) -> int | None:
for size in self.capture_sizes:
if actual_size <= size:
return size
return None

def warmup(self, device: torch.device, dtype: torch.dtype = torch.long):
if device.type != "cuda":
logger.info("CUDA Graph warmup skipped: device %s is not CUDA", device)
return

if not self.enabled:
logger.info("CUDA Graph is disabled, skipping warmup")
return

if self._warmed_up:
logger.warning("CUDA Graph already warmed up, skipping")
def warmup(
self,
device: torch.device,
dtype: torch.dtype = torch.long,
codec_chunk_frames: int = 0,
codec_left_context_frames: int = 0,
):
if device.type != "cuda" or not self.enabled or self._warmed_up:
return

self._device = device
self.decoder.eval()

if not self._explicit_sizes:
self.capture_sizes = self.compute_capture_sizes(
codec_chunk_frames=codec_chunk_frames,
codec_left_context_frames=codec_left_context_frames,
)

logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes)

# Warmup runs to ensure CUDA memory is allocated
for size in self.capture_sizes:
dummy_codes = torch.zeros(
1,
self.num_quantizers,
size,
dtype=dtype,
device=device,
)
dummy = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device)
with torch.no_grad():
_ = self.decoder(dummy_codes)
_ = self.decoder(dummy)

torch.cuda.synchronize(device)

for size in self.capture_sizes:
try:
self._capture_graph_for_size(size, device, dtype)
self._capture(size, device, dtype)
logger.info(" Captured CUDA Graph for size=%d", size)
except Exception:
logger.warning(" Failed to capture CUDA Graph for size=%d", size, exc_info=True)
logger.warning(" Failed to capture graph for size=%d", size, exc_info=True)

self._warmed_up = True
logger.info("CUDA Graph warmup complete. Captured %d graphs.", len(self.graphs))

def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch.dtype):
static_input = torch.zeros(
1,
self.num_quantizers,
size,
dtype=dtype,
device=device,
)
logger.info("CUDA Graph warmup complete: %d/%d captured", len(self.graphs), len(self.capture_sizes))

def _capture(self, size: int, device: torch.device, dtype: torch.dtype):
static_input = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device)
with torch.no_grad():
_ = self.decoder(static_input)

torch.cuda.synchronize(device)

graph = CUDAGraph()
Expand All @@ -122,10 +137,7 @@ def _capture_graph_for_size(self, size: int, device: torch.device, dtype: torch.
self.static_outputs[size] = static_output

def decode(self, codes: torch.Tensor) -> torch.Tensor:
if not self.enabled or not self._warmed_up:
return self.decoder(codes)

if codes.shape[0] != 1:
if not self.enabled or not self._warmed_up or codes.shape[0] != 1:
return self.decoder(codes)

actual_size = codes.shape[-1]
Expand All @@ -136,14 +148,10 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:

self.static_inputs[padded_size].zero_()
self.static_inputs[padded_size][:, :, :actual_size] = codes

self.graphs[padded_size].replay()

output = self.static_outputs[padded_size]
total_upsample = self.decoder.total_upsample
actual_output_len = actual_size * total_upsample

return output[..., :actual_output_len].clone()
actual_out_len = actual_size * self.decoder.total_upsample
return self.static_outputs[padded_size][..., :actual_out_len].clone()

def chunked_decode_with_cudagraph(
self,
Expand Down
30 changes: 15 additions & 15 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,17 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
self._output_sample_rate = out_sr
self._total_upsample = int(decoder.total_upsample)

# Precompute SnakeBeta exp caches (benefits both Triton and eager paths)
if hasattr(decoder, "precompute_snake_caches"):
decoder.precompute_snake_caches()

if hasattr(decoder, "enable_cudagraph"):
device = self._module_device(decoder)
if device.type == "cuda":
try:
capture_sizes = None
chunk_frames = 0
left_frames = 0

model_cfg = getattr(self.vllm_config, "model_config", None)
connector_cfg = getattr(model_cfg, "stage_connector_config", None)
extra_cfg = (
Expand All @@ -122,12 +128,12 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
if isinstance(extra_cfg, dict):
chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0)
left_frames = int(extra_cfg.get("codec_left_context_frames") or 0)
if chunk_frames > 0 and left_frames >= 0:
from .cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper

steady_window = left_frames + chunk_frames
capture_sizes = sorted({*CUDAGraphDecoderWrapper.DEFAULT_CAPTURE_SIZES, steady_window})
decoder.enable_cudagraph(capture_sizes=capture_sizes, device=device)
decoder.enable_cudagraph(
device=device,
codec_chunk_frames=chunk_frames,
codec_left_context_frames=left_frames,
)
logger.info("Code2Wav decoder CUDA Graph enabled")
except Exception:
logger.warning("Failed to enable CUDA Graph for Code2Wav decoder", exc_info=True)
Expand Down Expand Up @@ -265,18 +271,12 @@ def forward(
pass

# Decode directly via decoder.chunked_decode(), staying entirely on GPU.
# For single request: no padding needed, fast path.
# For multiple requests: decode each individually to avoid padding overhead.
# Each request decoded individually with CUDA graph replay at bs=1.
wav_tensors: list[torch.Tensor] = []
if len(valid_codes_qf) == 1:
codes_bqf = valid_codes_qf[0].unsqueeze(0) # [1, Q, F]
for codes_qf in valid_codes_qf:
codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F]
wav = decoder.chunked_decode(codes_bqf) # [1, 1, wav_len]
wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len]
else:
for codes_qf in valid_codes_qf:
codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F]
wav = decoder.chunked_decode(codes_bqf)
wav_tensors.append(wav.squeeze(0).squeeze(0))

audios: list[torch.Tensor] = [empty] * num_req
srs = [sr_tensor] * num_req
Expand Down
Loading
Loading