Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False
self._disable_compile_for_graph = False

def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
Expand All @@ -327,11 +328,20 @@ def _setup_compile(self) -> None:
if self._compile_attempted:
return
self._compile_attempted = True
if self._disable_compile_for_graph:
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
dynamic=True,
options={"epilogue_fusion": False},
)
except Exception as exc:
logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc)
self._compiled_model_fwd = self.model.forward
return
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
# Keep the helper compiler separate from vLLM's outer
# cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
Expand Down Expand Up @@ -366,10 +376,10 @@ def warmup_compile(

@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
# Default-on compile only pays off for single-request decode. For
# batched decode, eager preserves loaded throughput and avoids the
# regression seen with batch>1 compiled execution.
model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
if self._disable_compile_for_graph:
model_fwd = self._compiled_model_fwd or self.model.forward
else:
model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.has_postprocess = True
self.mtp_hidden_size = int(self.text_config.hidden_size)
self.talker_mtp_output_key = "audio_codes"
self.talker_mtp_graph_safe = True
self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"}

# Qwen3 transformer backbone.
Expand Down Expand Up @@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
slow_ar_config=self.text_config,
prefix="fast_ar",
)
if self.talker_mtp_graph_safe:
self.fast_ar._disable_compile_for_graph = True

# Constant logit mask: allow only semantic tokens + im_end.
vocab = int(self.text_config.vocab_size)
Expand Down Expand Up @@ -623,18 +626,13 @@ def talker_mtp(
inputs_embeds_out = input_embeds.reshape(bsz, -1).clone()

semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id)
if semantic_mask.any():
semantic_codes = audio_codes[semantic_mask].clamp(min=0)
offsets = (
torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
).unsqueeze(0)
codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)

# Normalize by sqrt(num_codebooks + 1) as in the reference model
# (scale_codebook_embeddings=True for fish_qwen3_omni).
inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt(
self._num_codebooks + 1
)
semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1)
offsets = (
torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
).unsqueeze(0)
codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1)
inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out)

return inputs_embeds_out, audio_codes.to(dtype=torch.long)

Expand Down Expand Up @@ -745,14 +743,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if truncated:
logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated)

try:
self.fast_ar.warmup_compile(
device=self.codebook_embeddings.weight.device,
dtype=torch.bfloat16,
batch_sizes=(1,),
)
except Exception as exc:
logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
if not getattr(self, "talker_mtp_graph_safe", False):
try:
self.fast_ar.warmup_compile(
device=self.codebook_embeddings.weight.device,
dtype=torch.bfloat16,
batch_sizes=(1,),
)
except Exception as exc:
logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)

codec_device = self.codebook_embeddings.weight.device
_load_dac_codec(
Expand Down
62 changes: 62 additions & 0 deletions vllm_omni/worker/gpu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,68 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata):
return sampling_metadata
return replace(sampling_metadata, output_token_ids=output_token_ids)

def capture_model(self) -> int:
    """Run the parent's graph capture, then capture the talker_mtp graphs.

    Returns:
        The value returned by the parent class's ``capture_model``,
        passed through unchanged.
    """
    parent_result = super().capture_model()
    # The talker_mtp capture runs only after the parent capture has finished.
    self._capture_talker_mtp_graphs()
    return parent_result

def _capture_talker_mtp_graphs(self) -> None:
    """Warm up and capture ``self.talker_mtp`` for every configured capture size.

    Returns immediately unless ``self.has_talker_mtp`` is truthy and
    ``self.talker_mtp`` is a ``CUDAGraphWrapper``. Otherwise, for each size in
    ``compilation_config.cudagraph_capture_sizes`` (largest first), the
    talker is invoked ``cudagraph_num_of_warmups`` times under
    ``CUDAGraphMode.NONE`` and then once under ``CUDAGraphMode.FULL``.

    Raises:
        RuntimeError: when warmup or capture raises a ``RuntimeError``; the
            original error is chained and its text appended to the message.
    """
    from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper

    is_wrapped = self.has_talker_mtp and isinstance(self.talker_mtp, CUDAGraphWrapper)
    if not is_wrapped:
        return

    from vllm.compilation.monitor import set_cudagraph_capturing_enabled
    from vllm.distributed.parallel_state import graph_capture

    sizes_desc = sorted(self.compilation_config.cudagraph_capture_sizes, reverse=True)
    warmup_iters = self.compilation_config.cudagraph_num_of_warmups
    logger.info("Capturing talker_mtp graphs for sizes %s", sizes_desc)

    def run_talker(runtime_mode, descriptor, call_args):
        # One talker_mtp invocation under a forward context for the given mode.
        with set_forward_context(
            None,
            self.vllm_config,
            cudagraph_runtime_mode=runtime_mode,
            batch_descriptor=descriptor,
        ):
            self.talker_mtp(*call_args)

    set_cudagraph_capturing_enabled(True)
    try:
        with torch.inference_mode(), graph_capture(device=self.device):
            for size in sizes_desc:
                # One token per request: num_tokens == num_reqs == size.
                _, descriptor, _, _, _ = self._determine_batch_execution_and_padding(
                    num_tokens=size,
                    num_reqs=size,
                    num_scheduled_tokens_np=np.ones(size, dtype=np.int32),
                    max_num_scheduled_tokens=1,
                    use_cascade_attn=False,
                )
                # Slice the runner's GPU buffers to the descriptor's token count.
                padded = descriptor.num_tokens
                call_args = (
                    self.talker_mtp_input_ids.gpu[:padded],
                    self.talker_mtp_inputs_embeds.gpu[:padded],
                    self.last_talker_hidden.gpu[:padded],
                    self.text_step.gpu[:padded],
                )

                for _ in range(warmup_iters):
                    run_talker(CUDAGraphMode.NONE, descriptor, call_args)

                run_talker(CUDAGraphMode.FULL, descriptor, call_args)
                torch.cuda.synchronize()

        logger.info("Captured talker_mtp graphs for %d sizes", len(sizes_desc))
    except RuntimeError as e:
        raise RuntimeError(
            f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}"
        ) from e
    finally:
        # Always clear the capturing flag, whether or not capture succeeded.
        set_cudagraph_capturing_enabled(False)

@torch.inference_mode()
def execute_model(
self,
Expand Down
6 changes: 2 additions & 4 deletions vllm_omni/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,9 @@ def load_model(self, *args, **kwargs) -> None:
self.has_talker_mtp = True
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
# Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
# have a separate .talker sub-module. TTS models' code predictor
# has internal AR loops / torch.multinomial — not graph-safe.
has_separate_talker = getattr(self.model, "talker", None) is not None
if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False)
if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe):
self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
# TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
hidden_size = int(
Expand Down
Loading