diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index 8bbb643ebec..1c5b80daed0 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -304,44 +304,82 @@ def __init__( self._num_codebooks = config.num_codebooks self._fast_dim = config.hidden_size - # Pre-allocated buffers (lazily initialised on first forward). + # Pre-allocated buffers (lazily initialised in _ensure_buffers). self._embed_buf: torch.Tensor | None = None - self._pos_ids: torch.Tensor | None = None + + # torch.compile state (lazily initialized in _setup_compile). + # CUDA graph capture is handled by the outer CUDAGraphWrapper + # in OmniGPUModelRunner, not here. self._compiled_model_fwd: object | None = None self._compile_attempted = False - self._compile_failed = False + self._bucket_sizes: list[int] = [] + self._bucket_pos_ids: dict[int, torch.Tensor] = {} - def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: + def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes + # Use max of max_num_seqs and max_cudagraph_capture_size so + # the buffer is large enough for CUDAGraphWrapper's padded batches. 
+ max_bsz = max( + self._vllm_config.scheduler_config.max_num_seqs, + self._vllm_config.compilation_config.max_cudagraph_capture_size, + 1, + ) if ( self._embed_buf is not None - and self._embed_buf.shape[0] >= bsz + and self._embed_buf.shape[0] >= max_bsz and self._embed_buf.device == device and self._embed_buf.dtype == dtype ): return - self._embed_buf = torch.zeros(bsz, max_seq, self._fast_dim, dtype=dtype, device=device) - self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device) + self._embed_buf = torch.zeros(max_bsz, max_seq, self._fast_dim, dtype=dtype, device=device) + + def _padded_bsz(self, bsz: int) -> int: + for bucket in self._bucket_sizes: + if bsz <= bucket: + return bucket + return bsz def _setup_compile(self) -> None: + """Lazily set up compiled forward and position_ids buffers. + + No inner CUDA graph capture — the outer CUDAGraphWrapper in + OmniGPUModelRunner captures the entire talker_mtp call (including + this forward) as one graph. We just need torch.compile for + kernel fusion and fixed-shape position_ids for determinism. + """ if self._compile_attempted: return self._compile_attempted = True + try: self._compiled_model_fwd = torch.compile( self.model.forward, - # Keep the helper compiler separate from vLLM's outer - # cudagraph-managed Stage-0 execution. - mode="default", - dynamic=True, - fullgraph=False, + dynamic=False, + options={"epilogue_fusion": False}, ) except Exception as exc: - self._compile_failed = True - logger.warning("Failed to enable torch.compile for Fish Speech Fast AR: %s", exc) + logger.warning("Fish Speech Fast AR: torch.compile failed: %s", exc) self._compiled_model_fwd = self.model.forward - else: - logger.info("Enabled torch.compile for Fish Speech Fast AR forward (mode=default)") + return + + # Build batch-size buckets and pre-allocate position_ids. 
+ max_bsz = max(self._vllm_config.scheduler_config.max_num_seqs, 1) + bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz] + if max_bsz not in bucket_sizes: + bucket_sizes.append(max_bsz) + self._bucket_sizes = sorted(bucket_sizes) + + max_seq = self._num_codebooks + 1 + device = next(self.model.parameters()).device + embed_buf = self._embed_buf + + for bsz in self._bucket_sizes: + pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1) + self._bucket_pos_ids[bsz] = pos_ids + # Warmup compiled fn to trigger Inductor compilation. + for _ in range(3): + self._compiled_model_fwd(embed_buf[:bsz, :max_seq, :], pos_ids) + logger.info("Fish Speech Fast AR: compile warmup done for buckets %s", self._bucket_sizes) @torch.inference_mode() def warmup_compile( @@ -350,9 +388,10 @@ def warmup_compile( dtype: torch.dtype, batch_sizes: tuple[int, ...] = (1,), ) -> None: + self._ensure_buffers(device, dtype) self._setup_compile() - if self._compiled_model_fwd is self.model.forward or self._compile_failed: - return + # Run a full forward per warmup batch size so the outer + # CUDAGraphWrapper sees stable shapes during capture. for batch_size in batch_sizes: hidden = torch.zeros((batch_size, self.slow_ar_config.hidden_size), device=device, dtype=dtype) semantic = torch.full( @@ -364,22 +403,6 @@ def warmup_compile( self(hidden, semantic, do_sample=False) torch.cuda.synchronize(device) - @torch.inference_mode() - def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor: - # Default-on compile only pays off for single-request decode. For - # batched decode, eager preserves loaded throughput and avoids the - # regression seen with batch>1 compiled execution. 
- model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward - try: - return model_fwd(step_input, step_pos_ids) - except Exception as exc: - if model_fwd is self.model.forward or self._compile_failed: - raise - self._compile_failed = True - self._compiled_model_fwd = self.model.forward - logger.warning("Fish Speech Fast AR torch.compile fallback to eager after runtime failure: %s", exc) - return self.model.forward(step_input, step_pos_ids) - @torch.inference_mode() def forward( self, @@ -393,6 +416,11 @@ def forward( ) -> torch.Tensor: """Predict residual codebook codes 0..num_codebooks-1 autoregressively. + Each step runs the compiled (or eager) model forward over the + full-length embedding buffer [padded_bsz, max_seq, H], then + indexes the relevant position for logits. No CUDA graph is + captured here; the outer CUDAGraphWrapper captures the call. + Args: slow_ar_hidden: [B, hidden_size] last hidden state from Slow AR. semantic_token_id: [B] or [B, 1] sampled semantic token IDs (in vocab space). @@ -409,18 +437,20 @@ def forward( semantic_begin = self.slow_ar_config.semantic_begin_id semantic_end = self.slow_ar_config.semantic_end_id codebook_size = semantic_end - semantic_begin + 1 # 4096 - # Convert vocab-space semantic token to codebook index. - # Clamp to valid range: im_end or other non-semantic tokens map to 0 (pad). semantic_code = (semantic_token_id.reshape(bsz) - semantic_begin).clamp(min=0, max=codebook_size - 1) all_codes = torch.empty(bsz, num_cb, dtype=torch.long, device=device) all_codes[:, 0] = semantic_code - self._ensure_buffers(bsz, device, dtype) + self._ensure_buffers(device, dtype) self._setup_compile() embed_buf = self._embed_buf - pos_ids = self._pos_ids + max_seq = num_cb + 1 + + # Pad batch to a pre-warmed torch.compile bucket size (fixed shapes). + padded_bsz = self._padded_bsz(bsz) + embed_buf[:padded_bsz].zero_() # Position 0: projected Slow AR hidden state. 
projected = self.fast_project_in(slow_ar_hidden.reshape(bsz, -1)) @@ -432,23 +462,20 @@ def forward( use_sampling = do_sample and temperature > 0 inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 - - # Residual codebook size (1024) vs semantic codebook size (4096). - # The fast_output head has codebook_size (4096) outputs, but residual - # codebooks only have 1024 entries. Truncate logits for steps > 0. residual_codebook_size = 1024 + # Resolve compiled forward and position_ids for this bucket. + model_fwd = self._compiled_model_fwd or self.model.forward + pos_ids = self._bucket_pos_ids.get(padded_bsz) + if pos_ids is None: + pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1) + for step in range(1, num_cb): - seq_len = step + 1 - step_input = embed_buf[:bsz, :seq_len, :] - # Use a dense 2D position tensor for every batch size; stride-0 - # views from expand() were fragile under compiled execution. - step_pos_ids = pos_ids[:seq_len].unsqueeze(0).repeat(bsz, 1) + # Full-buffer forward (fixed shape — captured by outer CUDAGraphWrapper). + hidden_out = model_fwd(embed_buf[:padded_bsz, :max_seq, :], pos_ids) - hidden_out = self._run_model(step_input, step_pos_ids, bsz) - logits = self.fast_output(self.fast_norm(hidden_out[:, -1, :])) + logits = self.fast_output(self.fast_norm(hidden_out[:bsz, step, :])) - # Residual codebooks (step >= 1) only have 1024 entries. 
if step >= 1: logits = logits[:, :residual_codebook_size] diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 4ad2a1fa63b..a47b9c24db6 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -237,6 +237,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix="fast_ar", ) + # Expose .talker so OmniGPUModelRunner wraps talker_mtp in + # CUDAGraphWrapper, capturing the entire Fast AR codebook + # decode loop in one graph replay. + self.talker = self.fast_ar + # Constant logit mask: allow only semantic tokens + im_end. vocab = int(self.text_config.vocab_size) semantic_mask = torch.zeros((vocab,), dtype=torch.bool) @@ -622,19 +627,16 @@ def talker_mtp( # This ensures the Slow AR sees codes from FastAR(hidden_{t-1}). inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() + # Branchless codebook embedding (CUDA-graph-safe: no data-dependent + # control flow). Compute for all positions, mask via torch.where. semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) - if semantic_mask.any(): - semantic_codes = audio_codes[semantic_mask].clamp(min=0) - offsets = ( - torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size - ).unsqueeze(0) - codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) - - # Normalize by sqrt(num_codebooks + 1) as in the reference model - # (scale_codebook_embeddings=True for fish_qwen3_omni). 
- inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt( - self._num_codebooks + 1 - ) + clamped_codes = audio_codes.clamp(min=0) + offsets = ( + torch.arange(self._num_codebooks, device=dev, dtype=clamped_codes.dtype) * self._codebook_size + ).unsqueeze(0) + codebook_sum = self.codebook_embeddings(clamped_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) + normalized = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1) + inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), normalized, inputs_embeds_out) return inputs_embeds_out, audio_codes.to(dtype=torch.long)