Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/offline_inference/qwen3_omni/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,13 @@ def main(args):
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")

# Convert to numpy array and ensure correct format
# In async_chunk mode, audio may arrive as a list of chunks
if isinstance(audio_tensor, list):
import torch

audio_tensor = torch.cat(
[(t if isinstance(t, torch.Tensor) else torch.tensor(t)).flatten() for t in audio_tensor]
)
audio_numpy = audio_tensor.float().detach().cpu().numpy()

# Ensure audio is 1D (flatten if needed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,7 @@ def test_compute_capture_sizes(kwargs, expected_in, not_expected):
)
def test_snakebeta_triton_vs_eager(batch, channels, seq_len):
"""Fused Triton SnakeBeta kernel must match eager PyTorch output."""
from vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import (
SnakeBeta,
)
from vllm_omni.model_executor.models.common.snake_activation import SnakeBeta

if not SnakeBeta._init_triton():
pytest.skip("Triton not available")
Expand Down
135 changes: 135 additions & 0 deletions vllm_omni/model_executor/models/common/snake_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team.
"""Shared SnakeBeta activation for speech decoders (Qwen3-TTS, Qwen3-Omni Code2Wav)."""

import torch
from torch import nn
from torch.nn import Parameter
from vllm.logger import init_logger

logger = init_logger(__name__)


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude
    of the periodic components: ``x + 1 / (exp(beta) + eps) * sin^2(x * exp(alpha))``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper
          by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://huggingface.co/papers/2006.08195

    Two execution paths exist:
        - a fused Triton kernel, used only for non-empty 3-D CUDA inputs while
          autograd is disabled;
        - an eager PyTorch path, used everywhere else and as the fallback when
          the Triton path raises.
    """

    _triton_kernel = None  # None = untried, False = unavailable, callable = ready
    _TRITON_MAX_BLOCK_T = 4096  # upper bound for time-axis tile size

    @staticmethod
    def _init_triton():
        """Load and JIT-compile the fused Triton kernel (once).

        Returns:
            True when the kernel is ready, False when Triton cannot be imported
            (or was disabled after an earlier failure). The outcome is stored on
            the class, so it is shared by every SnakeBeta instance.
        """
        if SnakeBeta._triton_kernel is not None:
            return SnakeBeta._triton_kernel is not False
        try:
            import triton
            import triton.language as tl
        except ImportError:
            SnakeBeta._triton_kernel = False
            return False

        @triton.jit
        def _kernel(  # noqa: N803
            x_ptr,
            exp_alpha_ptr,
            inv_beta_ptr,
            out_ptr,
            stride_b,
            stride_c,
            t_len,
            block_t: tl.constexpr,
        ):
            """Fused SnakeBeta using precomputed exp(α) and 1/(exp(β)+ε)."""
            # One program per (batch, channel, time-tile). The time axis is
            # addressed with unit stride, so the caller must pass a contiguous x.
            bid = tl.program_id(0)
            cid = tl.program_id(1)
            t_off = tl.program_id(2) * block_t + tl.arange(0, block_t)
            mask = t_off < t_len

            x = tl.load(x_ptr + bid * stride_b + cid * stride_c + t_off, mask=mask, other=0.0)
            ea = tl.load(exp_alpha_ptr + cid)
            ib = tl.load(inv_beta_ptr + cid)
            sin_val = tl.sin(x * ea)
            result = x + ib * sin_val * sin_val

            tl.store(out_ptr + bid * stride_b + cid * stride_c + t_off, result, mask=mask)

        SnakeBeta._triton_kernel = _kernel
        return True

    def __init__(self, in_features: int, alpha: float = 1.0):
        """Create per-channel ``alpha`` / ``beta`` parameters of length ``in_features``.

        NOTE: both parameters are built as ``zeros * alpha`` and therefore start
        at zero whatever ``alpha`` is; presumably real values come from loaded
        weights (TODO confirm with the checkpoint loader).
        """
        super().__init__()
        self.in_features = in_features

        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        # Epsilon added to exp(beta) before taking the reciprocal.
        self.no_div_by_zero = 0.000000001

        # Precomputed buffers (populated by precompute_exp_cache)
        self.register_buffer("_exp_alpha", None, persistent=False)
        self.register_buffer("_inv_beta", None, persistent=False)

    def precompute_exp_cache(self):
        """Materialize exp(alpha) and 1/(exp(beta)+eps) as frozen buffers.

        The buffers are snapshots taken under ``torch.no_grad()``; call this
        again if ``alpha`` / ``beta`` are modified afterwards.
        """
        with torch.no_grad():
            self._exp_alpha = torch.exp(self.alpha).contiguous()
            self._inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).contiguous()

    @property
    def _cached(self):
        """Whether ``precompute_exp_cache`` has populated the buffers."""
        return self._exp_alpha is not None

    @staticmethod
    def _triton_eligible(hidden_states: torch.Tensor) -> bool:
        """Return True when the fused kernel can handle ``hidden_states``.

        The kernel launch unpacks exactly three dimensions and needs a
        non-empty grid. Anything else is routed to the eager path up front,
        because an exception inside the Triton path disables the kernel for
        the whole process.
        """
        return (
            hidden_states.is_cuda
            and not torch.is_grad_enabled()
            and hidden_states.dim() == 3
            and hidden_states.numel() > 0
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """SnakeBeta := x + 1/b * sin^2(x*a)"""
        # Eligibility is checked first so that Triton is never imported for
        # inputs that cannot use it (e.g. CPU tensors).
        if self._triton_eligible(hidden_states) and self._init_triton():
            try:
                return self._triton_forward(hidden_states)
            except Exception:
                logger.warning("Triton SnakeBeta failed, falling back to eager", exc_info=True)
                # Disable the fused path for every instance from now on.
                SnakeBeta._triton_kernel = False
        return self._eager_forward(hidden_states)

    def _eager_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Reference PyTorch implementation; broadcasts the parameters over (1, C, 1)."""
        # The cached buffers are detached snapshots. Use them only when autograd
        # is off, otherwise alpha/beta would silently receive no gradient.
        if self._cached and not torch.is_grad_enabled():
            exp_alpha = self._exp_alpha.unsqueeze(0).unsqueeze(-1)
            inv_beta = self._inv_beta.unsqueeze(0).unsqueeze(-1)
        else:
            exp_alpha = torch.exp(self.alpha).unsqueeze(0).unsqueeze(-1)
            inv_beta = (1.0 / (torch.exp(self.beta) + self.no_div_by_zero)).unsqueeze(0).unsqueeze(-1)
        hidden_states = hidden_states + inv_beta * torch.pow(torch.sin(hidden_states * exp_alpha), 2)
        return hidden_states

    def _triton_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on a non-empty (B, C, T) CUDA tensor.

        Populates the exp cache on first use. Returns a new contiguous tensor
        with the same shape and dtype as ``x``.
        """
        import triton

        if not self._cached:
            self.precompute_exp_cache()

        # The kernel indexes the time axis with stride 1.
        x = x.contiguous()
        B, C, T = x.shape
        out = torch.empty_like(x)
        # Tile the time axis: a power-of-two tile, capped at _TRITON_MAX_BLOCK_T.
        block_t = min(triton.next_power_of_2(T), self._TRITON_MAX_BLOCK_T)
        self._triton_kernel[(B, C, triton.cdiv(T, block_t))](
            x,
            self._exp_alpha,
            self._inv_beta,
            out,
            x.stride(0),
            x.stride(1),
            t_len=T,
            block_t=block_t,
        )
        return out
27 changes: 27 additions & 0 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,18 @@ def sample(

# ==================== Weight Loading ====================

def _get_codec_frame_config(self) -> tuple[int, int]:
    """Extract codec_chunk_frames and codec_left_context_frames from stage connector config.

    Reads ``self.vllm_config.model_config.stage_connector_config`` and its
    ``extra`` mapping. The connector config may be a dict or an object with an
    ``extra`` attribute. Any missing level, or an ``extra`` that is ``None`` or
    empty, yields zeros.

    Returns:
        ``(codec_chunk_frames, codec_left_context_frames)`` as ints; a missing
        or falsy value is reported as 0.
    """
    model_cfg = getattr(self.vllm_config, "model_config", None)
    connector_cfg = getattr(model_cfg, "stage_connector_config", None)
    if isinstance(connector_cfg, dict):
        extra = connector_cfg.get("extra")
    else:
        extra = getattr(connector_cfg, "extra", None)
    # An explicit ``"extra": None`` must behave like a missing key in both branches.
    extra = extra or {}
    chunk_frames = int(extra.get("codec_chunk_frames", 0) or 0)
    left_frames = int(extra.get("codec_left_context_frames", 0) or 0)
    return chunk_frames, left_frames

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights for all components of the omni model."""
loaded_weights = set()
Expand Down Expand Up @@ -1293,6 +1305,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
code2wav_loaded = add_prefix_to_loaded_weights(code2wav_loaded, "code2wav")
loaded_weights.update(code2wav_loaded)

# Precompute SnakeBeta caches and enable CUDA graph for Code2Wav decoder
try:
self.code2wav.precompute_snake_caches()
if hasattr(self.code2wav, "enable_cudagraph"):
chunk_frames, left_frames = self._get_codec_frame_config()
self.code2wav.enable_cudagraph(
codec_chunk_frames=chunk_frames,
codec_left_context_frames=left_frames,
)
except Exception:
logger.warning(
"Failed to enable CUDA Graph for Code2Wav; falling back to eager.",
exc_info=True,
)

# Log summary
logger.info(
"Loaded %d weights for Qwen3OmniMoe (stage=%s)",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A broad `except Exception` here means a bug in the config extraction (e.g. a typo in the key name) will silently disable CUDA graphs with no easy way to diagnose it. Since this runs once at load time, just let it propagate — the caller can decide whether CUDA graph failure is fatal.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed now.

Expand Down
102 changes: 84 additions & 18 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Qwen3OmniMoeCode2WavDecoderBlock,
Qwen3OmniMoeCode2WavTransformerModel,
Qwen3OmniMoeConvNeXtBlock,
SnakeBeta,
)
from vllm.config import VllmConfig # type: ignore
from vllm.logger import init_logger # type: ignore
Expand All @@ -28,6 +27,8 @@
WeightsMapper,
)

from vllm_omni.model_executor.models.common.snake_activation import SnakeBeta

logger = init_logger(__name__)


Expand Down Expand Up @@ -119,6 +120,61 @@ def __init__(
]
self.decoder = nn.ModuleList(decoder)

# CUDA Graph support — reuses CUDAGraphDecoderWrapper from Qwen3-TTS
self._cudagraph_enabled = False
self._cudagraph_wrapper = None

def precompute_snake_caches(self):
    """Populate the exp(alpha) / 1/(exp(beta)+eps) cache of every SnakeBeta submodule.

    Walks ``self.modules()``, calls ``precompute_exp_cache()`` on each
    SnakeBeta instance found, and logs how many were processed (nothing is
    logged when there are none).
    """
    snake_modules = filter(lambda m: isinstance(m, SnakeBeta), self.modules())
    processed = 0
    # enumerate(start=1) leaves ``processed`` equal to the number of modules visited.
    for processed, snake in enumerate(snake_modules, start=1):
        snake.precompute_exp_cache()
    if processed:
        logger.info("Precomputed exp caches for %d SnakeBeta activations", processed)

def enable_cudagraph(
    self,
    device: torch.device | None = None,
    codec_chunk_frames: int = 0,
    codec_left_context_frames: int = 0,
):
    """Enable CUDA graph acceleration (same pattern as Qwen3-TTS Code2Wav).

    Builds a ``CUDAGraphDecoderWrapper`` around this decoder and runs its
    ``warmup``. On success the wrapper is stored on ``self`` and
    ``_cudagraph_enabled`` is set.

    Args:
        device: Target device; defaults to the device of the first parameter.
        codec_chunk_frames: Forwarded unchanged to ``wrapper.warmup``.
        codec_left_context_frames: Forwarded unchanged to ``wrapper.warmup``.

    Raises:
        Exception: Whatever ``warmup`` raises is re-raised after the CUDA graph
            state on ``self`` has been reset to disabled.
    """
    from vllm_omni.model_executor.models.qwen3_tts.cuda_graph_decoder_wrapper import (
        CUDAGraphDecoderWrapper,
    )

    target = next(self.parameters()).device if device is None else device
    if target.type != "cuda":
        # Nothing to capture off-GPU; leave the decoder in eager mode.
        logger.warning("Cannot enable CUDA Graph: not on CUDA device (got %s)", target)
        return

    graph_wrapper = CUDAGraphDecoderWrapper(
        decoder=self,
        num_quantizers=self.config.num_quantizers,
        enabled=True,
    )
    warmup_kwargs = {
        "dtype": torch.long,
        "codec_chunk_frames": codec_chunk_frames,
        "codec_left_context_frames": codec_left_context_frames,
    }
    try:
        graph_wrapper.warmup(target, **warmup_kwargs)
    except Exception:
        # Never leave a half-initialised wrapper attached to the decoder.
        self._cudagraph_wrapper, self._cudagraph_enabled = None, False
        raise
    else:
        self._cudagraph_wrapper = graph_wrapper
        self._cudagraph_enabled = True
        logger.info(
            "CUDA Graph enabled for Code2Wav: num_quantizers=%d, sizes=%s",
            self.config.num_quantizers,
            self._cudagraph_wrapper.capture_sizes,
        )

def forward(self, codes: torch.Tensor) -> torch.Tensor:
"""
Convert num_quantizers-layer RVQ codes to audio waveform.
Expand Down Expand Up @@ -168,6 +224,8 @@ def chunked_decode(
Decode long sequences in chunks to avoid OOM.

Uses overlapping chunks with left context to avoid boundary artifacts.
When CUDA graphs are enabled, delegates chunk-level decoding to the
CUDAGraphDecoderWrapper for reduced kernel launch overhead.

Args:
codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
Expand All @@ -180,35 +238,40 @@ def chunked_decode(
codes. For ``batch_size == 1``, this is a list containing a
single tensor with shape ``[1, waveform_len]``.
"""
wavs = []
start_index = 0
# Use CUDA graph wrapper for chunk-level decode when available
if self._cudagraph_enabled and self._cudagraph_wrapper is not None:
batch_wav = self._cudagraph_wrapper.chunked_decode_with_cudagraph(codes, chunk_size, left_context_size)
else:
wavs = []
start_index = 0

while start_index < codes.shape[-1]:
end_index = min(start_index + chunk_size, codes.shape[-1])
context_size = left_context_size if start_index >= left_context_size else start_index

while start_index < codes.shape[-1]:
end_index = min(start_index + chunk_size, codes.shape[-1])
context_size = left_context_size if start_index >= left_context_size else start_index
# Extract chunk with left context
codes_chunk = codes[..., start_index - context_size : end_index]

# Extract chunk with left context
codes_chunk = codes[..., start_index - context_size : end_index]
# Decode chunk
wav_chunk = self(codes_chunk)

# Decode chunk
wav_chunk = self(codes_chunk)
# Remove context from output (context_size * total_upsample samples)
wavs.append(wav_chunk[..., context_size * self.total_upsample :])

# Remove context from output (context_size * total_upsample samples)
wavs.append(wav_chunk[..., context_size * self.total_upsample :])
start_index = end_index

start_index = end_index
batch_wav = torch.cat(wavs, dim=-1)

if seq_token_counts is not None:
code_seq_lens = [seq_len // self.config.num_quantizers for seq_len in seq_token_counts]
else:
# Fallback: assume all batch elements share the same sequence length.
code_seq_lens = [codes.shape[-1]] * codes.shape[0]
batch_wav = torch.cat(wavs, dim=-1)
wavs = []
result = []
for idx, code_seq_len in enumerate(code_seq_lens):
wav_chunk = batch_wav[idx, :, : code_seq_len * self.total_upsample]
wavs.append(wav_chunk)
return wavs
result.append(wav_chunk)
return result

def chunked_decode_streaming(
self,
Expand Down Expand Up @@ -241,7 +304,10 @@ def chunked_decode_streaming(
left_context_size = [0] * codes.shape[0]
# Decode chunk
wavs = []
batch_wav = self(codes)
if self._cudagraph_enabled and self._cudagraph_wrapper is not None:
batch_wav = self._cudagraph_wrapper.decode(codes)
else:
batch_wav = self(codes)
if seq_token_counts is not None:
code_seq_lens = [n // self.config.num_quantizers for n in seq_token_counts]
else:
Expand Down
Loading
Loading