From af8568ea00e57f77a0283e1d001751bf2e1ad40b Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 7 Apr 2026 03:57:24 +0800 Subject: [PATCH 1/5] feat(fish-speech): enable CUDA Graph capture for Fast AR code predictor Enable CUDAGraphWrapper for Fish Speech S2 Pro's Fast AR via opt-in talker_mtp_graph_safe attribute. - Wrap talker_mtp in CUDAGraphWrapper in GPUARModelRunner.load_model (not __init__, since has_talker_mtp is set during load_model) - Add _capture_talker_mtp_graphs() for explicit warmup+capture after capture_model() completes; capture largest bsz first to pre-allocate Fast AR internal buffers at max size (avoids buffer reallocation invalidating previously captured graphs) - Replace semantic_mask.any() with torch.where (graph-safe) - Disable torch.compile inside Fast AR when outer graph is active - Fallback to eager on capture failure with compile state reset Only affects models with talker_mtp_graph_safe = True. gpu_model_runner.py is untouched. Benchmark (H20, Fish Speech S2 Pro, vllm 0.19.0): Baseline: 2048ms -> Optimized: 955ms (-53.4%) Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/fish_speech/fish_speech_fast_ar.py | 7 +- .../models/fish_speech/fish_speech_slow_ar.py | 40 ++++----- vllm_omni/worker/gpu_ar_model_runner.py | 87 +++++++++++++++++++ 3 files changed, 112 insertions(+), 22 deletions(-) diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index 8bbb643ebec..fdd3236c06b 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -310,6 +310,7 @@ def __init__( self._compiled_model_fwd: object | None = None self._compile_attempted = False self._compile_failed = False + self._disable_compile_for_graph = False def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes @@ -327,11 +328,13 @@ def _setup_compile(self) -> None: if self._compile_attempted: return self._compile_attempted = True + if self._disable_compile_for_graph: + self._compiled_model_fwd = self.model.forward + logger.info("Fast AR torch.compile disabled (outer CUDA Graph active)") + return try: self._compiled_model_fwd = torch.compile( self.model.forward, - # Keep the helper compiler separate from vLLM's outer - # cudagraph-managed Stage-0 execution. mode="default", dynamic=True, fullgraph=False, diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 4ad2a1fa63b..40cda7d5c29 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.has_postprocess = True self.mtp_hidden_size = int(self.text_config.hidden_size) self.talker_mtp_output_key = "audio_codes" + self.talker_mtp_graph_safe = True self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"} # Qwen3 transformer backbone. @@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): slow_ar_config=self.text_config, prefix="fast_ar", ) + if self.talker_mtp_graph_safe: + self.fast_ar._disable_compile_for_graph = True # Constant logit mask: allow only semantic tokens + im_end. vocab = int(self.text_config.vocab_size) @@ -622,19 +625,15 @@ def talker_mtp( # This ensures the Slow AR sees codes from FastAR(hidden_{t-1}). inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() + # torch.where avoids host-device sync (.any()) for CUDA Graph compatibility. semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) - if semantic_mask.any(): - semantic_codes = audio_codes[semantic_mask].clamp(min=0) - offsets = ( - torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size - ).unsqueeze(0) - codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) - - # Normalize by sqrt(num_codebooks + 1) as in the reference model - # (scale_codebook_embeddings=True for fish_qwen3_omni). - inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt( - self._num_codebooks + 1 - ) + semantic_codes = audio_codes.clamp(min=0) + offsets = ( + torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size + ).unsqueeze(0) + codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) + norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1) + inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out) return inputs_embeds_out, audio_codes.to(dtype=torch.long) @@ -745,14 +744,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if truncated: logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated) - try: - self.fast_ar.warmup_compile( - device=self.codebook_embeddings.weight.device, - dtype=torch.bfloat16, - batch_sizes=(1,), - ) - except Exception as exc: - logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) + if not getattr(self, "talker_mtp_graph_safe", False): + try: + self.fast_ar.warmup_compile( + device=self.codebook_embeddings.weight.device, + dtype=torch.bfloat16, + batch_sizes=(1,), + ) + except Exception as exc: + logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) codec_device = self.codebook_embeddings.weight.device _load_dac_codec( diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 01ec23acb47..0923e34c48a 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -138,6 +138,93 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata): return sampling_metadata return replace(sampling_metadata, output_token_ids=output_token_ids) + def load_model(self, *args, **kwargs) -> None: + super().load_model(*args, **kwargs) + from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper + + if ( + self.has_talker_mtp + and not isinstance(self.talker_mtp, CUDAGraphWrapper) + and getattr(self.model, "talker_mtp_graph_safe", False) + and self.compilation_config.cudagraph_mode is not None + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + self.talker_mtp = CUDAGraphWrapper( + self.model.talker_mtp, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL, + ) + + def capture_model(self) -> int: + result = super().capture_model() + self._capture_talker_mtp_graphs() + return result + + def _capture_talker_mtp_graphs(self) -> None: + from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper + + if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper): + return + if not getattr(self.model, "talker_mtp_graph_safe", False): + return + + from vllm.compilation.monitor import set_cudagraph_capturing_enabled + from vllm.distributed.parallel_state import graph_capture + + capture_sizes = self.compilation_config.cudagraph_capture_sizes + num_warmups = self.compilation_config.cudagraph_num_of_warmups + # Capture largest first so _ensure_buffers pre-allocates for max bsz. + # Smaller captures then reuse the same buffer (shape[0] >= bsz). + capture_sizes = sorted(capture_sizes, reverse=True) + logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes) + + set_cudagraph_capturing_enabled(True) + try: + with torch.inference_mode(), graph_capture(device=self.device): + for bsz in capture_sizes: + _, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=bsz, + num_reqs=bsz, + num_scheduled_tokens_np=np.ones(bsz, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + n = batch_desc.num_tokens + ids = self.talker_mtp_input_ids.gpu[:n] + emb = self.talker_mtp_inputs_embeds.gpu[:n] + hid = self.last_talker_hidden.gpu[:n] + ts = self.text_step.gpu[:n] + + for _ in range(num_warmups): + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=batch_desc, + ): + self.talker_mtp(ids, emb, hid, ts) + + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_desc, + ): + self.talker_mtp(ids, emb, hid, ts) + torch.cuda.synchronize() + + logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes)) + except RuntimeError as e: + logger.warning("talker_mtp graph capture failed, falling back to eager: %s", e) + self.talker_mtp = self.model.talker_mtp + # Re-enable torch.compile since we're back to eager + fast_ar = getattr(self.model, "fast_ar", None) + if fast_ar is not None: + fast_ar._disable_compile_for_graph = False + fast_ar._compile_attempted = False + finally: + set_cudagraph_capturing_enabled(False) + @torch.inference_mode() def execute_model( self, From ecee23c18d973ef01bb84b6b600eb6c5c640a1c2 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 9 Apr 2026 00:05:40 +0800 Subject: [PATCH 2/5] feat(fish-speech): add torch.compile + review fixes for Fast AR CUDA Graph - Extend CUDAGraphWrapper wrap condition with talker_mtp_graph_safe opt-in - Enable torch.compile(dynamic=True, epilogue_fusion=False) inside graph - Use compiled forward for all batch sizes in graph mode - Replace semantic_mask.any() with torch.where for graph compatibility - Add clamp(max=codebook_size-1) for codebook index safety - Clean fallback state reset (_compiled_model_fwd=None) Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/fish_speech/fish_speech_fast_ar.py | 19 ++++++++++----- .../models/fish_speech/fish_speech_slow_ar.py | 3 +-- vllm_omni/worker/gpu_ar_model_runner.py | 23 +------------------ vllm_omni/worker/gpu_model_runner.py | 6 ++--- 4 files changed, 17 insertions(+), 34 deletions(-) diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index fdd3236c06b..22a2744ff5d 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -329,8 +329,15 @@ def _setup_compile(self) -> None: return self._compile_attempted = True if self._disable_compile_for_graph: - self._compiled_model_fwd = self.model.forward - logger.info("Fast AR torch.compile disabled (outer CUDA Graph active)") + try: + self._compiled_model_fwd = torch.compile( + self.model.forward, + dynamic=True, + options={"epilogue_fusion": False}, + ) + except Exception as exc: + logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc) + self._compiled_model_fwd = self.model.forward return try: self._compiled_model_fwd = torch.compile( @@ -369,10 +376,10 @@ def warmup_compile( @torch.inference_mode() def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor: - # Default-on compile only pays off for single-request decode. For - # batched decode, eager preserves loaded throughput and avoids the - # regression seen with batch>1 compiled execution. - model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward + if self._disable_compile_for_graph: + model_fwd = self._compiled_model_fwd or self.model.forward + else: + model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward try: return model_fwd(step_input, step_pos_ids) except Exception as exc: diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 40cda7d5c29..1c549421f65 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -625,9 +625,8 @@ def talker_mtp( # This ensures the Slow AR sees codes from FastAR(hidden_{t-1}). inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() - # torch.where avoids host-device sync (.any()) for CUDA Graph compatibility. semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) - semantic_codes = audio_codes.clamp(min=0) + semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1) offsets = ( torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size ).unsqueeze(0) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 0923e34c48a..7d35630d983 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -138,23 +138,6 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata): return sampling_metadata return replace(sampling_metadata, output_token_ids=output_token_ids) - def load_model(self, *args, **kwargs) -> None: - super().load_model(*args, **kwargs) - from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper - - if ( - self.has_talker_mtp - and not isinstance(self.talker_mtp, CUDAGraphWrapper) - and getattr(self.model, "talker_mtp_graph_safe", False) - and self.compilation_config.cudagraph_mode is not None - and self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): - self.talker_mtp = CUDAGraphWrapper( - self.model.talker_mtp, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL, - ) - def capture_model(self) -> int: result = super().capture_model() self._capture_talker_mtp_graphs() @@ -165,16 +148,12 @@ def _capture_talker_mtp_graphs(self) -> None: if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper): return - if not getattr(self.model, "talker_mtp_graph_safe", False): - return from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.distributed.parallel_state import graph_capture capture_sizes = self.compilation_config.cudagraph_capture_sizes num_warmups = self.compilation_config.cudagraph_num_of_warmups - # Capture largest first so _ensure_buffers pre-allocates for max bsz. - # Smaller captures then reuse the same buffer (shape[0] >= bsz). capture_sizes = sorted(capture_sizes, reverse=True) logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes) @@ -217,11 +196,11 @@ def _capture_talker_mtp_graphs(self) -> None: except RuntimeError as e: logger.warning("talker_mtp graph capture failed, falling back to eager: %s", e) self.talker_mtp = self.model.talker_mtp - # Re-enable torch.compile since we're back to eager fast_ar = getattr(self.model, "fast_ar", None) if fast_ar is not None: fast_ar._disable_compile_for_graph = False fast_ar._compile_attempted = False + fast_ar._compiled_model_fwd = None finally: set_cudagraph_capturing_enabled(False) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index a7abaf7b62a..aa05fbed304 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -100,11 +100,9 @@ def load_model(self, *args, **kwargs) -> None: self.has_talker_mtp = True cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that - # have a separate .talker sub-module. TTS models' code predictor - # has internal AR loops / torch.multinomial — not graph-safe. has_separate_talker = getattr(self.model, "talker", None) is not None - if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False) + if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe): self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. hidden_size = int( From 0c89a2370e5a0d32ebc1791f5d71008230da6182 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 9 Apr 2026 19:56:54 +0800 Subject: [PATCH 3/5] fix: use CUDAGraphWrapper.unwrap() in capture fallback path Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/worker/gpu_ar_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 7d35630d983..7a654e41881 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -195,7 +195,7 @@ def _capture_talker_mtp_graphs(self) -> None: logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes)) except RuntimeError as e: logger.warning("talker_mtp graph capture failed, falling back to eager: %s", e) - self.talker_mtp = self.model.talker_mtp + self.talker_mtp = self.talker_mtp.unwrap() fast_ar = getattr(self.model, "fast_ar", None) if fast_ar is not None: fast_ar._disable_compile_for_graph = False From 42293c91a4052d16aff9c99627429d435144aec9 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 9 Apr 2026 20:46:53 +0800 Subject: [PATCH 4/5] fix: raise error on talker_mtp graph capture failure instead of silent fallback Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/worker/gpu_ar_model_runner.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 7a654e41881..e77f340b73d 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -194,13 +194,10 @@ def _capture_talker_mtp_graphs(self) -> None: logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes)) except RuntimeError as e: - logger.warning("talker_mtp graph capture failed, falling back to eager: %s", e) - self.talker_mtp = self.talker_mtp.unwrap() - fast_ar = getattr(self.model, "fast_ar", None) - if fast_ar is not None: - fast_ar._disable_compile_for_graph = False - fast_ar._compile_attempted = False - fast_ar._compiled_model_fwd = None + raise RuntimeError( + f"talker_mtp graph capture failed for a model that declared " + f"talker_mtp_graph_safe=True: {e}" + ) from e finally: set_cudagraph_capturing_enabled(False) From e24c9e5a2c456158338ae8f14c1dae595db9a350 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 9 Apr 2026 20:48:26 +0800 Subject: [PATCH 5/5] style: ruff format Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/worker/gpu_ar_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index e77f340b73d..227855eaf6b 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -195,8 +195,7 @@ def _capture_talker_mtp_graphs(self) -> None: logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes)) except RuntimeError as e: raise RuntimeError( - f"talker_mtp graph capture failed for a model that declared " - f"talker_mtp_graph_safe=True: {e}" + f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}" ) from e finally: set_cudagraph_capturing_enabled(False)