From 8cfdf660017a2e0795fd46a0d7dc028154002b76 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Wed, 15 Apr 2026 11:33:13 +0800 Subject: [PATCH] perf(voxcpm2): manual CUDA Graph capture for scaffold/residual forward Capture and replay CUDA Graphs for the 28-layer scaffold and 8-layer residual model forwards during decode steps, eliminating per-step kernel launch overhead. Reduces average RTF from 0.135 to 0.106 (-21%) and improves concurrent throughput by 14-40%. Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/voxcpm2/minicpm4_paged.py | 20 ++ .../models/voxcpm2/voxcpm2_talker.py | 188 ++++++++++++++++-- 2 files changed, 189 insertions(+), 19 deletions(-) diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py index 40bacfff6c7..b87ec5aafef 100644 --- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py +++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py @@ -307,6 +307,16 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def precompute_fused_qkv(self) -> None: + """Materialize fused QKV weights before CUDA Graph capture.""" + for layer in self.layers: + attn = layer.self_attn + if attn._fused_qkv_weight is None: + attn._fused_qkv_weight = torch.cat( + [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight], + dim=0, + ).detach() + def compile_selective(self) -> list[str]: """Compile the full model forward as one graph. @@ -411,6 +421,16 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def precompute_fused_qkv(self) -> None: + """Materialize fused QKV weights before CUDA Graph capture.""" + for layer in self.layers: + attn = layer.self_attn + if attn._fused_qkv_weight is None: + attn._fused_qkv_weight = torch.cat( + [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight], + dim=0, + ).detach() + def compile_selective(self) -> list[str]: """Compile the full residual model forward as one graph (same strategy as base_lm).""" if self._compiled_layers: diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 94f06589046..02bcae821e1 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -10,6 +10,7 @@ from __future__ import annotations +import copy import dataclasses import logging import os @@ -21,6 +22,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, override_forward_context from vllm.logger import init_logger from vllm.model_executor.models.utils import ( AutoWeightsLoader, @@ -101,6 +103,14 @@ class _RequestState: last_decoded_audio: torch.Tensor | None = None +@dataclasses.dataclass +class _CapturedGraph: + graph: torch.cuda.CUDAGraph + input_embeds: torch.Tensor + positions: torch.Tensor + output: torch.Tensor + + # =================================================================== # Profiling timer # =================================================================== @@ -336,6 +346,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._perf = _PerfTimer(enabled=_ENABLE_PROFILING) self._cfm_buffers: _CFMBufferManager | None = None + self._enable_cuda_graph = True + self._scaffold_graphs: dict[int, _CapturedGraph] = {} + self._residual_graphs: dict[int, _CapturedGraph] = {} + self._max_cached_graphs = self._max_batch_size + self._cuda_graph_pool: tuple | None = None + self._cuda_graph_warmup_steps = 0 + self._cuda_graph_warmup_threshold = 3 self._active_states: dict[str, _RequestState] = {} self._current_request_id: str | None = None @@ -483,19 +500,24 @@ def _setup_torch_compile(self) -> None: except Exception as e: logger.warning("torch.compile AudioVAE failed: %s", e) - if not getattr(self.model, "_selective_compiled", False): - try: - targets.extend(f"scaffold.{t}" for t in self.model.compile_selective()) - self.model._selective_compiled = True - except Exception as e: - logger.warning("scaffold compile failed: %s", e) + if not self._enable_cuda_graph: + if not getattr(self.model, "_selective_compiled", False): + try: + targets.extend(f"scaffold.{t}" for t in self.model.compile_selective()) + self.model._selective_compiled = True + except Exception as e: + logger.warning("scaffold compile failed: %s", e) - if not getattr(self.residual_model, "_selective_compiled", False): - try: - targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective()) - self.residual_model._selective_compiled = True - except Exception as e: - logger.warning("residual compile failed: %s", e) + if not getattr(self.residual_model, "_selective_compiled", False): + try: + targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective()) + self.residual_model._selective_compiled = True + except Exception as e: + logger.warning("residual compile failed: %s", e) + else: + self.model.precompute_fused_qkv() + self.residual_model.precompute_fused_qkv() + targets.append("scaffold+residual (CUDA Graph, skipping compile)") if not getattr(self, "_projections_compiled", False): try: @@ -518,6 +540,90 @@ def _stop_fn(self, lm_h: torch.Tensor) -> torch.Tensor: tts = self.tts return tts.stop_head(tts.stop_actn(tts.stop_proj(lm_h))) + def _get_cuda_graph_pool(self) -> tuple: + if self._cuda_graph_pool is None: + self._cuda_graph_pool = torch.cuda.graph_pool_handle() + return self._cuda_graph_pool + + @staticmethod + def _nullify_volatile_metadata(ctx: Any) -> Any: + """Set ``scheduler_metadata`` to None on all attention layers. + + This is the only tensor FA3 reallocates each step (variable shape). + All other metadata tensors are persistent model-runner buffers. + Setting it to None makes FA3 use default scheduling (~0.1ms cost). + """ + if not isinstance(ctx.attn_metadata, dict): + return ctx + + ctx = copy.copy(ctx) + new_meta: dict[str, Any] = {} + for layer_name, meta in ctx.attn_metadata.items(): + if getattr(meta, "scheduler_metadata", None) is not None: + meta = copy.copy(meta) + meta.scheduler_metadata = None + new_meta[layer_name] = meta + ctx.attn_metadata = new_meta + return ctx + + def _capture_graph( + self, + model: nn.Module, + batch_size: int, + label: str, + is_residual: bool = False, + ) -> _CapturedGraph: + """Capture a CUDA Graph for *model* at *batch_size*.""" + hidden_size = self.config.hidden_size + dtype = self._side_dtype + dev = torch.device(self._device) + pool = self._get_cuda_graph_pool() + + model.precompute_fused_qkv() + + g = _CapturedGraph( + graph=torch.cuda.CUDAGraph(), + input_embeds=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype), + positions=torch.zeros(batch_size, device=dev, dtype=torch.long), + output=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype), + ) + + if is_residual: + call_kwargs = dict(positions=g.positions, inputs_embeds=g.input_embeds) + else: + call_kwargs = dict(input_ids=None, positions=g.positions, inputs_embeds=g.input_embeds) + + ctx = get_forward_context() + patched_ctx = self._nullify_volatile_metadata(ctx) + + with override_forward_context(patched_ctx): + for _ in range(3): + _ = model(**call_kwargs) + + with torch.cuda.graph(g.graph, pool=pool): + g.output = model(**call_kwargs) + + logger.info("CUDA Graph captured for %s (batch_size=%d)", label, batch_size) + return g + + def _replay_graph( + self, + g: _CapturedGraph, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + batch_size: int, + ) -> torch.Tensor: + """Copy fresh inputs into static buffers, then replay. + + No metadata copy needed: persistent buffers (seq_lens, slot_mapping, + etc.) are updated in-place by the model runner. scheduler_metadata + was nullified at capture time so no kernel references it. + """ + g.input_embeds[:batch_size].copy_(inputs_embeds[:batch_size]) + g.positions[:batch_size].copy_(positions[:batch_size]) + g.graph.replay() + return g.output[:batch_size].clone() + # -------------------- vllm hooks -------------------- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: @@ -534,12 +640,35 @@ def forward( self._perf.start("forward_total") dev = input_ids.device - model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) - if isinstance(model_output, IntermediateTensors): - return model_output - scaffold_hidden = model_output - if isinstance(scaffold_hidden, tuple): - scaffold_hidden = scaffold_hidden[0] + num_reqs = len(self._pending_requests) + num_decode = sum(1 for _, is_p, _, n in self._pending_requests if not is_p and n == 1) + is_all_decode = num_decode == num_reqs and num_reqs > 0 + + tts_compiled = getattr(self.tts.feat_decoder.estimator, "_compiled", False) if self._tts is not None else False + graph_ready = tts_compiled and self._cuda_graph_warmup_steps >= self._cuda_graph_warmup_threshold + if num_decode > 0: + self._cuda_graph_warmup_steps += 1 + + can_use_graph = ( + self._enable_cuda_graph and graph_ready and intermediate_tensors is None and inputs_embeds is not None + ) + + if can_use_graph and is_all_decode and num_reqs <= self._max_cached_graphs: + self._perf.start("scaffold_fwd") + if num_reqs not in self._scaffold_graphs: + self._scaffold_graphs[num_reqs] = self._capture_graph(self.model, num_reqs, "scaffold") + scaffold_hidden = self._replay_graph(self._scaffold_graphs[num_reqs], inputs_embeds, positions, num_reqs) + self._perf.stop("scaffold_fwd") + + else: + self._perf.start("scaffold_fwd") + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + self._perf.stop("scaffold_fwd") + if isinstance(model_output, IntermediateTensors): + return model_output + scaffold_hidden = model_output + if isinstance(scaffold_hidden, tuple): + scaffold_hidden = scaffold_hidden[0] # Phase 1: per-request FSQ + residual input token_offset = 0 @@ -571,7 +700,28 @@ def forward( if residual_inputs: batch_in = torch.cat(residual_inputs, dim=0) batch_pos = torch.cat(residual_positions, dim=0) - batch_out = self.residual_model(batch_pos, batch_in) + + residual_batch_size = batch_in.shape[0] + use_residual_graph = ( + self._enable_cuda_graph + and is_all_decode + and graph_ready + and residual_batch_size == num_reqs # 1 token per request + and residual_batch_size <= self._max_cached_graphs + ) + + self._perf.start("residual_fwd") + if use_residual_graph: + if residual_batch_size not in self._residual_graphs: + self._residual_graphs[residual_batch_size] = self._capture_graph( + self.residual_model, residual_batch_size, "residual", is_residual=True + ) + batch_out = self._replay_graph( + self._residual_graphs[residual_batch_size], batch_in, batch_pos, residual_batch_size + ) + else: + batch_out = self.residual_model(batch_pos, batch_in) + self._perf.stop("residual_fwd") # Phase 3: per-request LocDiT + update offset = 0