Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 77 additions & 50 deletions vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,44 +304,82 @@ def __init__(
self._num_codebooks = config.num_codebooks
self._fast_dim = config.hidden_size

# Pre-allocated buffers (lazily initialised on first forward).
# Pre-allocated buffers (lazily initialised in _ensure_buffers).
self._embed_buf: torch.Tensor | None = None
self._pos_ids: torch.Tensor | None = None

# torch.compile state (lazily initialized in _setup_compile).
# CUDA graph capture is handled by the outer CUDAGraphWrapper
# in OmniGPUModelRunner, not here.
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False
self._bucket_sizes: list[int] = []
self._bucket_pos_ids: dict[int, torch.Tensor] = {}

def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
# Use max of max_num_seqs and max_cudagraph_capture_size so
# the buffer is large enough for CUDAGraphWrapper's padded batches.
max_bsz = max(
self._vllm_config.scheduler_config.max_num_seqs,
self._vllm_config.compilation_config.max_cudagraph_capture_size,
1,
)
if (
self._embed_buf is not None
and self._embed_buf.shape[0] >= bsz
and self._embed_buf.shape[0] >= max_bsz
and self._embed_buf.device == device
and self._embed_buf.dtype == dtype
):
return
self._embed_buf = torch.zeros(bsz, max_seq, self._fast_dim, dtype=dtype, device=device)
self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device)
self._embed_buf = torch.zeros(max_bsz, max_seq, self._fast_dim, dtype=dtype, device=device)

def _padded_bsz(self, bsz: int) -> int:
for bucket in self._bucket_sizes:
if bsz <= bucket:
return bucket
return bsz

def _setup_compile(self) -> None:
"""Lazily set up compiled forward and position_ids buffers.

No inner CUDA graph capture — the outer CUDAGraphWrapper in
OmniGPUModelRunner captures the entire talker_mtp call (including
this forward) as one graph. We just need torch.compile for
kernel fusion and fixed-shape position_ids for determinism.
"""
if self._compile_attempted:
return
self._compile_attempted = True

try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
# Keep the helper compiler separate from vLLM's outer
# cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
dynamic=False,
options={"epilogue_fusion": False},
)
except Exception as exc:
self._compile_failed = True
logger.warning("Failed to enable torch.compile for Fish Speech Fast AR: %s", exc)
logger.warning("Fish Speech Fast AR: torch.compile failed: %s", exc)
self._compiled_model_fwd = self.model.forward
else:
logger.info("Enabled torch.compile for Fish Speech Fast AR forward (mode=default)")
return

# Build batch-size buckets and pre-allocate position_ids.
max_bsz = max(self._vllm_config.scheduler_config.max_num_seqs, 1)
bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
if max_bsz not in bucket_sizes:
bucket_sizes.append(max_bsz)
self._bucket_sizes = sorted(bucket_sizes)

max_seq = self._num_codebooks + 1
device = next(self.model.parameters()).device
embed_buf = self._embed_buf

for bsz in self._bucket_sizes:
pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
self._bucket_pos_ids[bsz] = pos_ids
# Warmup compiled fn to trigger Inductor compilation.
for _ in range(3):
self._compiled_model_fwd(embed_buf[:bsz, :max_seq, :], pos_ids)
logger.info("Fish Speech Fast AR: compile warmup done for buckets %s", self._bucket_sizes)

@torch.inference_mode()
def warmup_compile(
Expand All @@ -350,9 +388,10 @@ def warmup_compile(
dtype: torch.dtype,
batch_sizes: tuple[int, ...] = (1,),
) -> None:
self._ensure_buffers(device, dtype)
self._setup_compile()
if self._compiled_model_fwd is self.model.forward or self._compile_failed:
return
# Run a full forward per warmup batch size so the outer
# CUDAGraphWrapper sees stable shapes during capture.
for batch_size in batch_sizes:
hidden = torch.zeros((batch_size, self.slow_ar_config.hidden_size), device=device, dtype=dtype)
semantic = torch.full(
Expand All @@ -364,22 +403,6 @@ def warmup_compile(
self(hidden, semantic, do_sample=False)
torch.cuda.synchronize(device)

@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
# Default-on compile only pays off for single-request decode. For
# batched decode, eager preserves loaded throughput and avoids the
# regression seen with batch>1 compiled execution.
model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
if model_fwd is self.model.forward or self._compile_failed:
raise
self._compile_failed = True
self._compiled_model_fwd = self.model.forward
logger.warning("Fish Speech Fast AR torch.compile fallback to eager after runtime failure: %s", exc)
return self.model.forward(step_input, step_pos_ids)

@torch.inference_mode()
def forward(
self,
Expand All @@ -393,6 +416,11 @@ def forward(
) -> torch.Tensor:
"""Predict residual codebook codes 0..num_codebooks-1 autoregressively.

Each step replays a CUDA graph (or compiled forward) over the
full-length embedding buffer [padded_bsz, max_seq, H], then
indexes the relevant position for logits. Sampling happens
outside the graph.

Args:
slow_ar_hidden: [B, hidden_size] last hidden state from Slow AR.
semantic_token_id: [B] or [B, 1] sampled semantic token IDs (in vocab space).
Expand All @@ -409,18 +437,20 @@ def forward(
semantic_begin = self.slow_ar_config.semantic_begin_id
semantic_end = self.slow_ar_config.semantic_end_id
codebook_size = semantic_end - semantic_begin + 1 # 4096
# Convert vocab-space semantic token to codebook index.
# Clamp to valid range: im_end or other non-semantic tokens map to 0 (pad).
semantic_code = (semantic_token_id.reshape(bsz) - semantic_begin).clamp(min=0, max=codebook_size - 1)

all_codes = torch.empty(bsz, num_cb, dtype=torch.long, device=device)
all_codes[:, 0] = semantic_code

self._ensure_buffers(bsz, device, dtype)
self._ensure_buffers(device, dtype)
self._setup_compile()

embed_buf = self._embed_buf
pos_ids = self._pos_ids
max_seq = num_cb + 1

# Pad batch to a CUDA graph bucket size.
padded_bsz = self._padded_bsz(bsz)
embed_buf[:padded_bsz].zero_()

# Position 0: projected Slow AR hidden state.
projected = self.fast_project_in(slow_ar_hidden.reshape(bsz, -1))
Expand All @@ -432,23 +462,20 @@ def forward(

use_sampling = do_sample and temperature > 0
inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0

# Residual codebook size (1024) vs semantic codebook size (4096).
# The fast_output head has codebook_size (4096) outputs, but residual
# codebooks only have 1024 entries. Truncate logits for steps > 0.
residual_codebook_size = 1024

# Resolve compiled forward and position_ids for this bucket.
model_fwd = self._compiled_model_fwd or self.model.forward
pos_ids = self._bucket_pos_ids.get(padded_bsz)
if pos_ids is None:
pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1)

for step in range(1, num_cb):
seq_len = step + 1
step_input = embed_buf[:bsz, :seq_len, :]
# Use a dense 2D position tensor for every batch size; stride-0
# views from expand() were fragile under compiled execution.
step_pos_ids = pos_ids[:seq_len].unsqueeze(0).repeat(bsz, 1)
# Full-buffer forward (fixed shape — captured by outer CUDAGraphWrapper).
hidden_out = model_fwd(embed_buf[:padded_bsz, :max_seq, :], pos_ids)

hidden_out = self._run_model(step_input, step_pos_ids, bsz)
logits = self.fast_output(self.fast_norm(hidden_out[:, -1, :]))
logits = self.fast_output(self.fast_norm(hidden_out[:bsz, step, :]))

# Residual codebooks (step >= 1) only have 1024 entries.
if step >= 1:
logits = logits[:, :residual_codebook_size]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix="fast_ar",
)

# Expose .talker so OmniGPUModelRunner wraps talker_mtp in
# CUDAGraphWrapper, capturing the entire Fast AR codebook
# decode loop in one graph replay.
self.talker = self.fast_ar

# Constant logit mask: allow only semantic tokens + im_end.
vocab = int(self.text_config.vocab_size)
semantic_mask = torch.zeros((vocab,), dtype=torch.bool)
Expand Down Expand Up @@ -622,19 +627,16 @@ def talker_mtp(
# This ensures the Slow AR sees codes from FastAR(hidden_{t-1}).
inputs_embeds_out = input_embeds.reshape(bsz, -1).clone()

# Branchless codebook embedding (CUDA-graph-safe: no data-dependent
# control flow). Compute for all positions, mask via torch.where.
semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id)
if semantic_mask.any():
semantic_codes = audio_codes[semantic_mask].clamp(min=0)
offsets = (
torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
).unsqueeze(0)
codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)

# Normalize by sqrt(num_codebooks + 1) as in the reference model
# (scale_codebook_embeddings=True for fish_qwen3_omni).
inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt(
self._num_codebooks + 1
)
clamped_codes = audio_codes.clamp(min=0)
offsets = (
torch.arange(self._num_codebooks, device=dev, dtype=clamped_codes.dtype) * self._codebook_size
).unsqueeze(0)
codebook_sum = self.codebook_embeddings(clamped_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
normalized = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1)
inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), normalized, inputs_embeds_out)

return inputs_embeds_out, audio_codes.to(dtype=torch.long)

Expand Down
Loading