Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ def forward(
hidden_states = self.norm(hidden_states)
return hidden_states

def precompute_fused_qkv(self) -> None:
    """Build and cache the concatenated Q/K/V weight on every attention layer.

    For each entry of ``self.layers`` whose ``self_attn._fused_qkv_weight`` is
    still ``None``, the q/k/v projection weights are concatenated along dim 0
    and stored detached. Layers that already hold a fused weight are left
    untouched. Meant to be called before CUDA Graph capture.
    """
    for block in self.layers:
        attention = block.self_attn
        if attention._fused_qkv_weight is not None:
            # Already materialized; keep the existing tensor.
            continue
        projections = (attention.q_proj, attention.k_proj, attention.v_proj)
        stacked = torch.cat([proj.weight for proj in projections], dim=0)
        attention._fused_qkv_weight = stacked.detach()

def compile_selective(self) -> list[str]:
"""Compile the full model forward as one graph.

Expand Down Expand Up @@ -411,6 +421,16 @@ def forward(
hidden_states = self.norm(hidden_states)
return hidden_states

def precompute_fused_qkv(self) -> None:
    """Populate ``_fused_qkv_weight`` on each layer's attention module.

    The fused tensor is the q, k and v projection weights concatenated along
    dim 0 and detached. Attention modules that already carry a fused weight
    are skipped. Meant to be called before CUDA Graph capture.
    """
    for decoder_layer in self.layers:
        sa = decoder_layer.self_attn
        if sa._fused_qkv_weight is not None:
            continue
        fused = torch.cat(
            (sa.q_proj.weight, sa.k_proj.weight, sa.v_proj.weight),
            dim=0,
        )
        sa._fused_qkv_weight = fused.detach()

def compile_selective(self) -> list[str]:
"""Compile the full residual model forward as one graph (same strategy as base_lm)."""
if self._compiled_layers:
Expand Down
188 changes: 169 additions & 19 deletions vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from __future__ import annotations

import copy
import dataclasses
import logging
import os
Expand All @@ -21,6 +22,7 @@
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, override_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
Expand Down Expand Up @@ -101,6 +103,14 @@ class _RequestState:
last_decoded_audio: torch.Tensor | None = None


@dataclasses.dataclass
class _CapturedGraph:
    """Bundle a captured CUDA graph with the static tensors it reads and writes."""

    # Recorded graph; re-executed with ``graph.replay()``.
    graph: torch.cuda.CUDAGraph
    # Static buffer that fresh input embeddings are copied into before replay.
    input_embeds: torch.Tensor
    # Static buffer that fresh position ids are copied into before replay.
    positions: torch.Tensor
    # Tensor holding the model result once a replay has run.
    output: torch.Tensor


# ===================================================================
# Profiling timer
# ===================================================================
Expand Down Expand Up @@ -336,6 +346,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self._perf = _PerfTimer(enabled=_ENABLE_PROFILING)
self._cfm_buffers: _CFMBufferManager | None = None
self._enable_cuda_graph = True
self._scaffold_graphs: dict[int, _CapturedGraph] = {}
self._residual_graphs: dict[int, _CapturedGraph] = {}
self._max_cached_graphs = self._max_batch_size
self._cuda_graph_pool: tuple | None = None
self._cuda_graph_warmup_steps = 0
self._cuda_graph_warmup_threshold = 3

self._active_states: dict[str, _RequestState] = {}
self._current_request_id: str | None = None
Expand Down Expand Up @@ -483,19 +500,24 @@ def _setup_torch_compile(self) -> None:
except Exception as e:
logger.warning("torch.compile AudioVAE failed: %s", e)

if not getattr(self.model, "_selective_compiled", False):
try:
targets.extend(f"scaffold.{t}" for t in self.model.compile_selective())
self.model._selective_compiled = True
except Exception as e:
logger.warning("scaffold compile failed: %s", e)
if not self._enable_cuda_graph:
if not getattr(self.model, "_selective_compiled", False):
try:
targets.extend(f"scaffold.{t}" for t in self.model.compile_selective())
self.model._selective_compiled = True
except Exception as e:
logger.warning("scaffold compile failed: %s", e)

if not getattr(self.residual_model, "_selective_compiled", False):
try:
targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective())
self.residual_model._selective_compiled = True
except Exception as e:
logger.warning("residual compile failed: %s", e)
if not getattr(self.residual_model, "_selective_compiled", False):
try:
targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective())
self.residual_model._selective_compiled = True
except Exception as e:
logger.warning("residual compile failed: %s", e)
else:
self.model.precompute_fused_qkv()
self.residual_model.precompute_fused_qkv()
targets.append("scaffold+residual (CUDA Graph, skipping compile)")

if not getattr(self, "_projections_compiled", False):
try:
Expand All @@ -518,6 +540,90 @@ def _stop_fn(self, lm_h: torch.Tensor) -> torch.Tensor:
tts = self.tts
return tts.stop_head(tts.stop_actn(tts.stop_proj(lm_h)))

def _get_cuda_graph_pool(self) -> tuple:
    """Return the CUDA Graph memory-pool handle, creating it on first use.

    The handle is cached on ``self._cuda_graph_pool`` so every later call
    returns the same object.
    """
    pool = self._cuda_graph_pool
    if pool is None:
        pool = torch.cuda.graph_pool_handle()
        self._cuda_graph_pool = pool
    return pool

@staticmethod
def _nullify_volatile_metadata(ctx: Any) -> Any:
    """Return a forward context whose attention metadata has no ``scheduler_metadata``.

    If ``ctx.attn_metadata`` is not a dict, ``ctx`` is returned unchanged.
    Otherwise a shallow copy of ``ctx`` is returned with a new metadata dict:
    every entry that has a non-None ``scheduler_metadata`` is replaced by a
    shallow copy with that field set to None, and all other entries are
    reused as-is. Neither ``ctx`` nor its metadata objects are mutated.

    Rationale carried over from the original note: ``scheduler_metadata`` is
    the only tensor FA3 reallocates each step (variable shape), the other
    metadata tensors are persistent model-runner buffers, and FA3 falls back
    to default scheduling (~0.1ms cost) when the field is None.
    """
    if not isinstance(ctx.attn_metadata, dict):
        return ctx

    def _without_scheduler(meta: Any) -> Any:
        # Only copy entries that actually carry scheduler metadata.
        if getattr(meta, "scheduler_metadata", None) is None:
            return meta
        stripped = copy.copy(meta)
        stripped.scheduler_metadata = None
        return stripped

    patched = copy.copy(ctx)
    patched.attn_metadata = {
        layer_name: _without_scheduler(meta)
        for layer_name, meta in patched.attn_metadata.items()
    }
    return patched

def _capture_graph(
    self,
    model: nn.Module,
    batch_size: int,
    label: str,
    is_residual: bool = False,
) -> _CapturedGraph:
    """Capture a CUDA Graph of one ``model`` forward pass at ``batch_size``.

    Static input buffers are allocated, the model is run eagerly three times
    as warm-up, and one more run is recorded into a ``torch.cuda.CUDAGraph``
    that draws from the shared pool. Warm-up and capture both execute under a
    forward context whose per-layer ``scheduler_metadata`` has been cleared
    by ``_nullify_volatile_metadata``.

    Args:
        model: Module to capture. It must expose ``precompute_fused_qkv()``
            and accept ``positions`` / ``inputs_embeds`` keywords (plus
            ``input_ids`` unless ``is_residual`` is true).
        batch_size: Number of rows in the static input buffers.
        label: Name used only in the log message.
        is_residual: When true, ``input_ids`` is not passed to ``model``.

    Returns:
        The captured graph with its static input buffers and the output
        tensor produced during capture. A tuple model output is reduced to
        its first element, matching the eager scaffold path, so the result
        can always be sliced like a tensor on replay.
    """
    hidden_size = self.config.hidden_size
    dtype = self._side_dtype
    dev = torch.device(self._device)
    pool = self._get_cuda_graph_pool()

    # Materialize fused QKV weights up front, before anything is recorded.
    model.precompute_fused_qkv()

    graph = torch.cuda.CUDAGraph()
    static_embeds = torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype)
    static_positions = torch.zeros(batch_size, device=dev, dtype=torch.long)

    call_kwargs: dict[str, Any] = dict(positions=static_positions, inputs_embeds=static_embeds)
    if not is_residual:
        call_kwargs["input_ids"] = None

    patched_ctx = self._nullify_volatile_metadata(get_forward_context())

    with override_forward_context(patched_ctx):
        # Eager warm-up runs before recording.
        for _ in range(3):
            model(**call_kwargs)

        with torch.cuda.graph(graph, pool=pool):
            captured = model(**call_kwargs)

    # Pure-Python unwrap (no GPU work): keep only the hidden states when the
    # model returns a tuple, as the non-graph path does.
    if isinstance(captured, tuple):
        captured = captured[0]

    logger.info("CUDA Graph captured for %s (batch_size=%d)", label, batch_size)
    return _CapturedGraph(
        graph=graph,
        input_embeds=static_embeds,
        positions=static_positions,
        output=captured,
    )

def _replay_graph(
    self,
    g: _CapturedGraph,
    inputs_embeds: torch.Tensor,
    positions: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    """Load fresh inputs into the graph's static buffers, replay, and copy out.

    The first ``batch_size`` rows of ``inputs_embeds`` and ``positions`` are
    copied into ``g.input_embeds`` and ``g.positions``, the graph is replayed,
    and a clone of the first ``batch_size`` rows of ``g.output`` is returned
    so the caller does not alias the static output buffer.

    Attention metadata is not copied here. Per the original note, the
    persistent buffers (seq_lens, slot_mapping, etc.) are updated in place by
    the model runner, and ``scheduler_metadata`` was nullified at capture
    time so no kernel references it.
    """
    rows = slice(None, batch_size)
    g.input_embeds[rows].copy_(inputs_embeds[rows])
    g.positions[rows].copy_(positions[rows])
    g.graph.replay()
    result = g.output[rows]
    return result.clone()

# -------------------- vllm hooks --------------------

def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
Expand All @@ -534,12 +640,35 @@ def forward(
self._perf.start("forward_total")
dev = input_ids.device

model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
if isinstance(model_output, IntermediateTensors):
return model_output
scaffold_hidden = model_output
if isinstance(scaffold_hidden, tuple):
scaffold_hidden = scaffold_hidden[0]
num_reqs = len(self._pending_requests)
num_decode = sum(1 for _, is_p, _, n in self._pending_requests if not is_p and n == 1)
is_all_decode = num_decode == num_reqs and num_reqs > 0

tts_compiled = getattr(self.tts.feat_decoder.estimator, "_compiled", False) if self._tts is not None else False
graph_ready = tts_compiled and self._cuda_graph_warmup_steps >= self._cuda_graph_warmup_threshold
if num_decode > 0:
self._cuda_graph_warmup_steps += 1

can_use_graph = (
self._enable_cuda_graph and graph_ready and intermediate_tensors is None and inputs_embeds is not None
)

if can_use_graph and is_all_decode and num_reqs <= self._max_cached_graphs:
self._perf.start("scaffold_fwd")
if num_reqs not in self._scaffold_graphs:
self._scaffold_graphs[num_reqs] = self._capture_graph(self.model, num_reqs, "scaffold")
scaffold_hidden = self._replay_graph(self._scaffold_graphs[num_reqs], inputs_embeds, positions, num_reqs)
self._perf.stop("scaffold_fwd")

else:
self._perf.start("scaffold_fwd")
model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
self._perf.stop("scaffold_fwd")
if isinstance(model_output, IntermediateTensors):
return model_output
scaffold_hidden = model_output
if isinstance(scaffold_hidden, tuple):
scaffold_hidden = scaffold_hidden[0]

# Phase 1: per-request FSQ + residual input
token_offset = 0
Expand Down Expand Up @@ -571,7 +700,28 @@ def forward(
if residual_inputs:
batch_in = torch.cat(residual_inputs, dim=0)
batch_pos = torch.cat(residual_positions, dim=0)
batch_out = self.residual_model(batch_pos, batch_in)

residual_batch_size = batch_in.shape[0]
use_residual_graph = (
self._enable_cuda_graph
and is_all_decode
and graph_ready
and residual_batch_size == num_reqs # 1 token per request
and residual_batch_size <= self._max_cached_graphs
)

self._perf.start("residual_fwd")
if use_residual_graph:
if residual_batch_size not in self._residual_graphs:
self._residual_graphs[residual_batch_size] = self._capture_graph(
self.residual_model, residual_batch_size, "residual", is_residual=True
)
batch_out = self._replay_graph(
self._residual_graphs[residual_batch_size], batch_in, batch_pos, residual_batch_size
)
else:
batch_out = self.residual_model(batch_pos, batch_in)
self._perf.stop("residual_fwd")

# Phase 3: per-request LocDiT + update
offset = 0
Expand Down
Loading