diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index 8bbb643ebec..22a2744ff5d 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -310,6 +310,7 @@ def __init__( self._compiled_model_fwd: object | None = None self._compile_attempted = False self._compile_failed = False + self._disable_compile_for_graph = False def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes @@ -327,11 +328,20 @@ def _setup_compile(self) -> None: if self._compile_attempted: return self._compile_attempted = True + if self._disable_compile_for_graph: + try: + self._compiled_model_fwd = torch.compile( + self.model.forward, + dynamic=True, + options={"epilogue_fusion": False}, + ) + except Exception as exc: + logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc) + self._compiled_model_fwd = self.model.forward + return try: self._compiled_model_fwd = torch.compile( self.model.forward, - # Keep the helper compiler separate from vLLM's outer - # cudagraph-managed Stage-0 execution. mode="default", dynamic=True, fullgraph=False, @@ -366,10 +376,10 @@ def warmup_compile( @torch.inference_mode() def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor: - # Default-on compile only pays off for single-request decode. For - # batched decode, eager preserves loaded throughput and avoids the - # regression seen with batch>1 compiled execution. - model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward + if self._disable_compile_for_graph: + model_fwd = self._compiled_model_fwd or self.model.forward + else: + model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward try: return model_fwd(step_input, step_pos_ids) except Exception as exc: diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 4ad2a1fa63b..1c549421f65 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.has_postprocess = True self.mtp_hidden_size = int(self.text_config.hidden_size) self.talker_mtp_output_key = "audio_codes" + self.talker_mtp_graph_safe = True self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"} # Qwen3 transformer backbone. @@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): slow_ar_config=self.text_config, prefix="fast_ar", ) + if self.talker_mtp_graph_safe: + self.fast_ar._disable_compile_for_graph = True # Constant logit mask: allow only semantic tokens + im_end. vocab = int(self.text_config.vocab_size) @@ -623,18 +626,13 @@ def talker_mtp( inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) - if semantic_mask.any(): - semantic_codes = audio_codes[semantic_mask].clamp(min=0) - offsets = ( - torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size - ).unsqueeze(0) - codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) - - # Normalize by sqrt(num_codebooks + 1) as in the reference model - # (scale_codebook_embeddings=True for fish_qwen3_omni). - inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt( - self._num_codebooks + 1 - ) + semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1) + offsets = ( + torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size + ).unsqueeze(0) + codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) + norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1) + inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out) return inputs_embeds_out, audio_codes.to(dtype=torch.long) @@ -745,14 +743,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if truncated: logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated) - try: - self.fast_ar.warmup_compile( - device=self.codebook_embeddings.weight.device, - dtype=torch.bfloat16, - batch_sizes=(1,), - ) - except Exception as exc: - logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) + if not getattr(self, "talker_mtp_graph_safe", False): + try: + self.fast_ar.warmup_compile( + device=self.codebook_embeddings.weight.device, + dtype=torch.bfloat16, + batch_sizes=(1,), + ) + except Exception as exc: + logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) codec_device = self.codebook_embeddings.weight.device _load_dac_codec( diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 01ec23acb47..227855eaf6b 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -138,6 +138,68 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata): return sampling_metadata return replace(sampling_metadata, output_token_ids=output_token_ids) + def capture_model(self) -> int: + result = super().capture_model() + self._capture_talker_mtp_graphs() + return result + + def _capture_talker_mtp_graphs(self) -> None: + from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper + + if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper): + return + + from vllm.compilation.monitor import set_cudagraph_capturing_enabled + from vllm.distributed.parallel_state import graph_capture + + capture_sizes = self.compilation_config.cudagraph_capture_sizes + num_warmups = self.compilation_config.cudagraph_num_of_warmups + capture_sizes = sorted(capture_sizes, reverse=True) + logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes) + + set_cudagraph_capturing_enabled(True) + try: + with torch.inference_mode(), graph_capture(device=self.device): + for bsz in capture_sizes: + _, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=bsz, + num_reqs=bsz, + num_scheduled_tokens_np=np.ones(bsz, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + n = batch_desc.num_tokens + ids = self.talker_mtp_input_ids.gpu[:n] + emb = self.talker_mtp_inputs_embeds.gpu[:n] + hid = self.last_talker_hidden.gpu[:n] + ts = self.text_step.gpu[:n] + + for _ in range(num_warmups): + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=batch_desc, + ): + self.talker_mtp(ids, emb, hid, ts) + + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_desc, + ): + self.talker_mtp(ids, emb, hid, ts) + torch.cuda.synchronize() + + logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes)) + except RuntimeError as e: + raise RuntimeError( + f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}" + ) from e + finally: + set_cudagraph_capturing_enabled(False) + @torch.inference_mode() def execute_model( self, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index a7abaf7b62a..aa05fbed304 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -100,11 +100,9 @@ def load_model(self, *args, **kwargs) -> None: self.has_talker_mtp = True cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that - # have a separate .talker sub-module. TTS models' code predictor - # has internal AR loops / torch.multinomial — not graph-safe. has_separate_talker = getattr(self.model, "talker", None) is not None - if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False) + if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe): self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. hidden_size = int(