diff --git a/tests/model_executor/models/qwen3_tts/test_talker_mtp_cuda_graph.py b/tests/model_executor/models/qwen3_tts/test_talker_mtp_cuda_graph.py new file mode 100644 index 00000000000..8557ad8fe3f --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/test_talker_mtp_cuda_graph.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for TalkerMTPCudaGraphWrapper. + +Verifies: + - Warmup / graph capture mechanics. + - Output shape and validity of audio codes. + - Numerical equivalence with sampling disabled (no randomness). + - Batch size > 1 support via bucket-based graph capture. +""" + +from __future__ import annotations + +import importlib.util +import os + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +pytestmark = [pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")] + +DEVICE = torch.device("cuda:0") +VOCAB_SIZE = 8 +NUM_CODE_GROUPS = 3 +HIDDEN_SIZE = 16 + +# --------------------------------------------------------------------------- +# Module import (package or direct file path fallback) +# --------------------------------------------------------------------------- + +try: + from vllm_omni.model_executor.models.qwen3_tts.cuda_graph_talker_wrapper import ( + TalkerMTPCudaGraphWrapper, + ) +except Exception: + _WRAPPER_PATH = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + os.pardir, + os.pardir, + os.pardir, + os.pardir, + "vllm_omni", + "model_executor", + "models", + "qwen3_tts", + "cuda_graph_decoder_wrapper.py", + ) + ) + _spec = importlib.util.spec_from_file_location("cuda_graph_decoder_wrapper", _WRAPPER_PATH) + _mod = importlib.util.module_from_spec(_spec) + _spec.loader.exec_module(_mod) + TalkerMTPCudaGraphWrapper = _mod.TalkerMTPCudaGraphWrapper + + +# --------------------------------------------------------------------------- +# Synthetic models that mimic the real talker / code-predictor interface +# --------------------------------------------------------------------------- + + +class SyntheticCodePredictorConfig: + def __init__( + self, + vocab_size: int = VOCAB_SIZE, + num_code_groups: int = NUM_CODE_GROUPS, + hidden_size: int = HIDDEN_SIZE, + ): + self.vocab_size = vocab_size + self.num_code_groups = num_code_groups + self.hidden_size = hidden_size + + +class SyntheticTalkerConfig: + def __init__( + self, + num_code_groups: int = NUM_CODE_GROUPS, + hidden_size: int = HIDDEN_SIZE, + ): + self.num_code_groups = num_code_groups + self.hidden_size = hidden_size + + +class SyntheticCodePredictor(nn.Module): + def __init__(self, config: SyntheticCodePredictorConfig): + super().__init__() + self.config = config + self._num_groups = config.num_code_groups + + # One lm_head per residual step (steps 1 .. Q-1) + self.lm_heads = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] + ) + # Codec embeddings for embedding-summation step inside the wrapper + self._codec_embeddings = nn.ModuleList( + [nn.Embedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)] + ) + + def get_input_embeddings(self) -> nn.ModuleList: + return self._codec_embeddings + + def forward( + self, + layer0_code: torch.Tensor, + layer0_embed: torch.Tensor, + last_talker_hidden: torch.Tensor, + do_sample: bool = True, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + ) -> torch.Tensor: + bsz = int(layer0_code.shape[0]) + device = layer0_code.device + + all_codes = torch.zeros(bsz, self._num_groups, dtype=torch.long, device=device) + all_codes[:, 0] = layer0_code.reshape(bsz) + + use_sampling = do_sample and temperature > 0 + inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 + + # Use last_talker_hidden as the shared hidden state for all steps + hidden = last_talker_hidden.reshape(bsz, -1).to(self.lm_heads[0].weight.dtype) + + for step in range(1, self._num_groups): + logits = self.lm_heads[step - 1](hidden) # [bsz, vocab_size] + + if use_sampling: + scaled = logits * inv_temperature + if top_k > 0: + topk_vals, _ = scaled.topk(top_k, dim=-1) + scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf")) + probs = F.softmax(scaled, dim=-1) + next_ids = torch.multinomial(probs, num_samples=1) + else: + next_ids = logits.argmax(dim=-1, keepdim=True) + + all_codes[:, step] = next_ids.reshape(bsz) + + return all_codes + + +class SyntheticTalkerModel: + def __init__(self, predictor_config: SyntheticCodePredictorConfig): + self.code_predictor = SyntheticCodePredictor(predictor_config).to(device=DEVICE, dtype=torch.bfloat16) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def predictor_config(): + return SyntheticCodePredictorConfig() + + +@pytest.fixture(scope="module") +def talker_config(): + return SyntheticTalkerConfig() + + +@pytest.fixture(scope="module") +def talker_model(predictor_config): + torch.manual_seed(0) + return SyntheticTalkerModel(predictor_config) + + +@pytest.fixture(scope="module") +def wrapper(talker_model, talker_config): + w = TalkerMTPCudaGraphWrapper( + talker_model=talker_model, + talker_config=talker_config, + device=DEVICE, + enabled=True, + temperature=0.9, + top_k=VOCAB_SIZE, # allow all tokens + max_batch_size=4, + ) + w.warmup(DEVICE) + return w + + +def _random_inputs(bs: int = 1, hidden_size: int = HIDDEN_SIZE, seed: int | None = None): + if seed is not None: + torch.manual_seed(seed) + input_ids = torch.randint(0, VOCAB_SIZE, (bs,), dtype=torch.long, device=DEVICE) + last_id_hidden = torch.randn(bs, hidden_size, dtype=torch.bfloat16, device=DEVICE) + past_hidden = torch.randn(bs, hidden_size, dtype=torch.bfloat16, device=DEVICE) + text_step = torch.randn(bs, hidden_size, dtype=torch.bfloat16, device=DEVICE) + return input_ids, last_id_hidden, past_hidden, text_step + + +# --------------------------------------------------------------------------- +# 1. Warmup / capture mechanics and output shapes +# --------------------------------------------------------------------------- + + +def test_warmup_sets_captured_flag(talker_model, talker_config): + w = TalkerMTPCudaGraphWrapper( + talker_model=talker_model, + talker_config=talker_config, + device=DEVICE, + enabled=True, + top_k=VOCAB_SIZE, + ) + assert not w.captured + assert w.graph is None + w.warmup(DEVICE) + assert w.captured + assert w.graph is not None + + +def test_captures_all_buckets(talker_model, talker_config): + """Wrapper should capture one graph per bucket size.""" + w = TalkerMTPCudaGraphWrapper( + talker_model=talker_model, + talker_config=talker_config, + device=DEVICE, + enabled=True, + top_k=VOCAB_SIZE, + max_batch_size=4, + ) + w.warmup(DEVICE) + assert set(w.graphs.keys()) == {1, 2, 4} + + +@pytest.mark.parametrize("bs", [1, 2, 3, 4]) +def test_output_shapes(wrapper, bs): + inputs = _random_inputs(bs=bs, seed=42) + inputs_embeds, audio_codes = wrapper._talker_mtp(*inputs) + + assert inputs_embeds.shape == (bs, HIDDEN_SIZE), ( + f"Expected inputs_embeds shape ({bs}, {HIDDEN_SIZE}), got {inputs_embeds.shape}" + ) + assert audio_codes.shape == (bs, NUM_CODE_GROUPS), ( + f"Expected audio_codes shape ({bs}, {NUM_CODE_GROUPS}), got {audio_codes.shape}" + ) + + +# --------------------------------------------------------------------------- +# 2. Numerical equivalence with sampling disabled: CUDA graph vs eager +# --------------------------------------------------------------------------- + + +class _ArgmaxWrapper(TalkerMTPCudaGraphWrapper): + """TalkerMTPCudaGraphWrapper variant that captures argmax (do_sample=False).""" + + @torch.inference_mode + def _mtp_forward(self): + audio_codes = self.code_predictor.forward( + layer0_code=self.input_ids_buf, + layer0_embed=self.last_id_hidden_buf, + last_talker_hidden=self.past_hidden_buf, + do_sample=False, + ) + self.audio_codes_buf.copy_(audio_codes) + + bs = audio_codes.shape[0] + residual_ids = audio_codes[:, 1:] + self.inputs_embeds_out_buf.copy_(self.last_id_hidden_buf.reshape(bs, -1)) + codec_embeds = self.code_predictor.get_input_embeddings() + for i in range(self.num_code_groups - 1): + self.inputs_embeds_out_buf.add_(codec_embeds[i](residual_ids[:, i : i + 1]).reshape(bs, -1)) + self.inputs_embeds_out_buf.add_(self.text_step_buf.reshape(bs, -1)) + + +def _eager_mtp_argmax(predictor, input_ids, last_id_hidden, past_hidden, text_step): + bsz = input_ids.shape[0] + num_groups = predictor.config.num_code_groups + with torch.inference_mode(): + audio_codes = predictor.forward( + layer0_code=input_ids.reshape(bsz, 1), + layer0_embed=last_id_hidden.reshape(bsz, 1, -1), + last_talker_hidden=past_hidden.reshape(bsz, 1, -1), + do_sample=False, + ) + + residual_ids = audio_codes[:, 1:] + inputs_embeds = last_id_hidden.reshape(bsz, -1).clone() + codec_embeds = predictor.get_input_embeddings() + for i in range(num_groups - 1): + inputs_embeds.add_(codec_embeds[i](residual_ids[:, i : i + 1]).reshape(bsz, -1)) + inputs_embeds.add_(text_step.reshape(bsz, -1)) + + return inputs_embeds, audio_codes + + +@pytest.mark.parametrize("seed", [42, 99, 1]) +def test_graph_matches_eager_argmax(talker_model, talker_config, seed): + """Argmax CUDA graph must be bit-identical to eager argmax (bs=1).""" + w = _ArgmaxWrapper( + talker_model=talker_model, + talker_config=talker_config, + device=DEVICE, + enabled=True, + top_k=VOCAB_SIZE, + ) + w.warmup(DEVICE) + + input_ids, last_id_hidden, past_hidden, text_step = _random_inputs(bs=1, seed=seed) + graph_embeds, graph_codes = w._talker_mtp(input_ids, last_id_hidden, past_hidden, text_step) + eager_embeds, eager_codes = _eager_mtp_argmax( + talker_model.code_predictor, input_ids, last_id_hidden, past_hidden, text_step + ) + + torch.testing.assert_close( + graph_codes, eager_codes, atol=0, rtol=0, msg="audio_codes mismatch (argmax, no sampling)" + ) + torch.testing.assert_close( + graph_embeds, eager_embeds, atol=0, rtol=0, msg="inputs_embeds mismatch (argmax, no sampling)" + ) + + +@pytest.mark.parametrize("bs", [2, 3, 4]) +@pytest.mark.parametrize("seed", [42, 7]) +def test_graph_matches_eager_argmax_batched(talker_model, talker_config, bs, seed): + """Argmax CUDA graph must be bit-identical to eager argmax for bs > 1.""" + w = _ArgmaxWrapper( + talker_model=talker_model, + talker_config=talker_config, + device=DEVICE, + enabled=True, + top_k=VOCAB_SIZE, + max_batch_size=4, + ) + w.warmup(DEVICE) + + input_ids, last_id_hidden, past_hidden, text_step = _random_inputs(bs=bs, seed=seed) + graph_embeds, graph_codes = w._talker_mtp(input_ids, last_id_hidden, past_hidden, text_step) + eager_embeds, eager_codes = _eager_mtp_argmax( + talker_model.code_predictor, input_ids, last_id_hidden, past_hidden, text_step + ) + + torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0, msg=f"audio_codes mismatch (argmax, bs={bs})") + torch.testing.assert_close( + graph_embeds, eager_embeds, atol=0, rtol=0, msg=f"inputs_embeds mismatch (argmax, bs={bs})" + ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_talker_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_talker_wrapper.py new file mode 100644 index 00000000000..886adf270d2 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_talker_wrapper.py @@ -0,0 +1,245 @@ +import torch +from torch.cuda import CUDAGraph +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class TalkerMTPCudaGraphWrapper: + """ + CUDA Graph wrapper for talker_mtp (multi-token prediction). + + Captures the entire MTP pipeline for each batch-size bucket: + - Code predictor forward + - Embedding summation + - Text step addition + + At inference time the wrapper selects the smallest captured bucket that fits + the actual batch size, zero-pads the inputs, replays the corresponding graph, + and returns only the non-padded output rows. + """ + + def __init__( + self, + talker_model, + talker_config, + device="cuda", + enabled=True, + temperature=0.9, + top_k=50, + num_warmup_steps=3, + max_batch_size: int = 1, + ): + self.device = device + self.device_index = torch.device(device).index or 0 + self.enabled = enabled + + self.talker = talker_model + self.code_predictor = talker_model.code_predictor + self.num_code_groups = talker_config.num_code_groups + self.hidden_size = talker_config.hidden_size + self.vocab_size = talker_model.code_predictor.config.vocab_size + self.temperature = temperature + self.top_k = top_k + + self.batch_sizes = self._compute_bucket_sizes(max_batch_size) + + # Per-bucket static GPU buffers, keyed by batch size. + self._buffers: dict[int, dict[str, torch.Tensor]] = {} + for bs in self.batch_sizes: + self._buffers[bs] = { + "input_ids": torch.zeros(bs, 1, dtype=torch.long, device=device), + "last_id_hidden": torch.zeros(bs, 1, self.hidden_size, dtype=torch.bfloat16, device=device), + "past_hidden": torch.zeros(bs, 1, self.hidden_size, dtype=torch.bfloat16, device=device), + "text_step": torch.zeros(bs, 1, self.hidden_size, dtype=torch.bfloat16, device=device), + "audio_codes": torch.zeros(bs, self.num_code_groups, dtype=torch.long, device=device), + "inputs_embeds": torch.zeros(bs, self.hidden_size, dtype=torch.bfloat16, device=device), + } + + # Current bucket's buffer dict; always set before _mtp_forward() is called. + self._active_bufs: dict[str, torch.Tensor] = self._buffers[self.batch_sizes[0]] + + self.graphs: dict[int, CUDAGraph] = {} + + self.num_warmup_steps = num_warmup_steps + self.warmed_up = False + self.captured = False + + def _compute_bucket_sizes(self, max_batch_size: int) -> list[int]: + """Return sorted list of CUDA-graph bucket sizes covering 1..max_batch_size. + + Uses powers of 2 up to max_batch_size, then appends max_batch_size itself + if it is not already a power of 2. Always includes 1. + """ + sizes: list[int] = [] + b = 1 + while b <= max_batch_size: + sizes.append(b) + b *= 2 + if sizes[-1] < max_batch_size: + sizes.append(max_batch_size) + return sizes + + @property + def input_ids_buf(self) -> torch.Tensor: + return self._active_bufs["input_ids"] + + @property + def last_id_hidden_buf(self) -> torch.Tensor: + return self._active_bufs["last_id_hidden"] + + @property + def past_hidden_buf(self) -> torch.Tensor: + return self._active_bufs["past_hidden"] + + @property + def text_step_buf(self) -> torch.Tensor: + return self._active_bufs["text_step"] + + @property + def audio_codes_buf(self) -> torch.Tensor: + return self._active_bufs["audio_codes"] + + @property + def inputs_embeds_out_buf(self) -> torch.Tensor: + return self._active_bufs["inputs_embeds"] + + @property + def graph(self) -> CUDAGraph | None: + return self.graphs.get(1) + + @torch.inference_mode + def _mtp_forward(self): + """Run the full MTP pipeline once; this is the function captured by the graph. + + Calls the code predictor to generate residual codebook tokens, then + accumulates their embeddings together with the layer-0 hidden state and + the text step to produce the next-step input embedding. + Results are written into the active output buffers (_active_bufs). + """ + audio_codes = self.code_predictor.forward( + layer0_code=self.input_ids_buf, + layer0_embed=self.last_id_hidden_buf, + last_talker_hidden=self.past_hidden_buf, + do_sample=True, + temperature=self.temperature, + top_k=self.top_k, + ) + self.audio_codes_buf.copy_(audio_codes) + + layer0 = self.audio_codes_buf[:, :1] + invalid0 = (layer0 < 0) | (layer0 >= int(self.vocab_size)) + self.audio_codes_buf.masked_fill_(invalid0.expand_as(self.audio_codes_buf), 0) + residual_ids = self.audio_codes_buf[:, 1:] + + embeds = [self.last_id_hidden_buf] + for i in range(self.num_code_groups - 1): + emb = self.code_predictor.get_input_embeddings()[i](residual_ids[:, i : i + 1]) + embeds.append(emb) + + bs = self.input_ids_buf.shape[0] + summed = torch.cat(embeds, dim=1).sum(1, keepdim=True) + result = (summed + self.text_step_buf).reshape(bs, -1) + self.inputs_embeds_out_buf.copy_(result) + + def capture(self): + """Warm up and capture _mtp_forward as CUDA graphs for every bucket size. + Running the largest batch size first ensures that the code predictor's + _proj_buf is sized for the max batch size and is not reallocated. + """ + for bs in reversed(self.batch_sizes): + self._active_bufs = self._buffers[bs] + for _ in range(self.num_warmup_steps): + self._mtp_forward() + torch.cuda.synchronize(self.device) + + for bs in self.batch_sizes: + self._active_bufs = self._buffers[bs] + + # Capture on a dedicated side stream so the default stream is not + # polluted by graph memory. + with torch.cuda.device(self.device_index): + graph = CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + # One additional warmup on the capture stream. + self._mtp_forward() + s.synchronize() + with torch.cuda.graph(graph): + self._mtp_forward() + + torch.cuda.current_stream().wait_stream(s) + self.graphs[bs] = graph + + torch.cuda.synchronize() + self.captured = True + + def warmup(self, device: torch.device): + """Capture CUDA graphs for all batch-size buckets on the given device.""" + if not self.enabled: + logger.info("TalkerMTPCudaGraphWrapper: disabled, skipping capture") + return + if device.type != "cuda": + logger.info("CUDA Graph warmup skipped: device %s is not CUDA", device) + return + if self.warmed_up: + logger.warning("CUDA Graph already warmed up, skipping") + return + self.device = device + self.device_index = device.index or 0 + self.capture() + self.warmed_up = True + logger.info( + "TalkerMTPCudaGraphWrapper: CUDA graphs captured for batch sizes %s", + self.batch_sizes, + ) + + @torch.inference_mode() + def _talker_mtp(self, input_ids, last_id_hidden, past_hidden, text_step): + """Run one MTP step via graph replay. + Zero-pads the inputs to the smallest fitting bucket, unpads after replay. + + Args: + input_ids: Layer-0 token ids, shape [B] or [B, 1]. + last_id_hidden: Layer-0 hidden state, shape [B, H] or [B, 1, H]. + past_hidden: Previous talker hidden state, shape [B, H] or [B, 1, H]. + text_step: Current text hidden state, shape [B, H] or [B, 1, H]. + + Returns: + (inputs_embeds, audio_codes): shapes [B, H] and [B, num_code_groups]. + + Raises: + RuntimeError: If warmup() has not been called yet. + ValueError: If B exceeds the maximum captured bucket size. + """ + if not self.captured or not self.graphs: + raise RuntimeError("TalkerMTPCudaGraphWrapper: graph not captured — call warmup() first") + + actual_bs = input_ids.shape[0] + target_bs = min((b for b in self.graphs if b >= actual_bs), default=None) + if target_bs is None: + logger.warning( + "TalkerMTPCudaGraphWrapper: batch size %d exceeds max captured bucket %d, " + "falling back to eager execution", + actual_bs, + max(self.graphs), + ) + return self.talker._talker_mtp(input_ids, last_id_hidden, past_hidden, text_step) + + bufs = self._buffers[target_bs] + + bufs["input_ids"][:actual_bs].copy_(input_ids.reshape(actual_bs, 1)) + bufs["last_id_hidden"][:actual_bs].copy_(last_id_hidden.reshape(actual_bs, 1, -1)) + bufs["past_hidden"][:actual_bs].copy_(past_hidden.reshape(actual_bs, 1, -1)) + bufs["text_step"][:actual_bs].copy_(text_step.reshape(actual_bs, 1, -1)) + + if actual_bs < target_bs: + bufs["input_ids"][actual_bs:].zero_() + bufs["last_id_hidden"][actual_bs:].zero_() + bufs["past_hidden"][actual_bs:].zero_() + bufs["text_step"][actual_bs:].zero_() + + self.graphs[target_bs].replay() + + return bufs["inputs_embeds"][:actual_bs].clone(), bufs["audio_codes"][:actual_bs].clone() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index de248f0f330..c868211125c 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -25,6 +25,7 @@ from vllm.model_executor.models.qwen3 import Qwen3Model from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix from vllm.sequence import IntermediateTensors +from vllm.v1.utils import record_function_or_nullcontext from vllm_omni.model_executor.models.output_templates import OmniOutput @@ -406,6 +407,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._tokenizer = None self._speech_tokenizer: Qwen3TTSTokenizer | None = None + # CUDA Graph support + self._cudagraph_enabled = False + self._cudagraph_wrapper = None + # -------------------- vLLM required hooks -------------------- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: @@ -514,21 +519,13 @@ def preprocess( input_embeds: torch.Tensor | None, **info_dict: Any, ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - # Metadata may be passed flattened or under `additional_information`; normalize to flattened keys. - additional_information = info_dict.get("additional_information") - if isinstance(additional_information, dict): - merged: dict[str, Any] = {k: v for k, v in info_dict.items() if k != "additional_information"} - for k, v in additional_information.items(): - merged.setdefault(k, v) - info_dict = merged - span_len = int(input_ids.shape[0]) if span_len <= 0: return input_ids, input_embeds if input_embeds is not None else self.embed_input_ids(input_ids), {} text_list = info_dict.get("text") if not isinstance(text_list, list) or not text_list or not text_list[0]: - raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.") + raise ValueError("Missing model_intermediate_buffer['text'] for Qwen3-TTS AR talker.") task_type = (info_dict.get("task_type") or ["CustomVoice"])[0] codec_streaming_val = info_dict.get("codec_streaming") @@ -585,7 +582,9 @@ def preprocess( else: # Subsequent prefill chunk: slice from stored embeddings at running offset. if tts_pad_embed is None: - raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.") + raise RuntimeError( + "Missing `tts_pad_embed` in model_intermediate_buffer; prefill must initialize it." + ) offset = int(info_dict.get("talker_prefill_offset", 0) or 0) if offset < 0: offset = 0 @@ -617,7 +616,7 @@ def preprocess( # These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op. tts_pad_embed_buf = info_dict.get("tts_pad_embed") if not isinstance(tts_pad_embed_buf, torch.Tensor): - raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") + raise RuntimeError("Missing `tts_pad_embed` in model_intermediate_buffer; prefill must run first.") tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) tail = info_dict.get("tailing_text_hidden") @@ -628,10 +627,11 @@ def preprocess( text_step = tts_pad_embed new_tail = tail if isinstance(tail, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1])) - last_hidden = info_dict.get("last_talker_hidden") - if not isinstance(last_hidden, torch.Tensor): - raise RuntimeError("Missing `last_talker_hidden` in additional_information; postprocess must run.") - past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + last_hidden_gpu = info_dict.get("last_talker_hidden") + if not isinstance(last_hidden_gpu, torch.Tensor): + raise RuntimeError("Missing `last_talker_hidden` in model_intermediate_buffer; postprocess must run.") + assert last_hidden_gpu.device == input_ids.device + past_hidden = last_hidden_gpu # Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update. last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to( @@ -651,7 +651,7 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: # Stays on GPU - gpu_resident_buffer_keys avoids the CPU round-trip. if hidden_states.numel() == 0: return {} - last = hidden_states[-1, :].detach() + last = hidden_states[-1, :] return {"last_talker_hidden": last} # -------------------- prompt construction helpers -------------------- @@ -1440,7 +1440,7 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: elif task_type == "CustomVoice": speaker = (info_dict.get("speaker") or [""])[0] if not isinstance(speaker, str) or not speaker.strip(): - raise ValueError("CustomVoice requires additional_information.speaker.") + raise ValueError("CustomVoice requires model_intermediate_buffer['speaker'].") spk_id_map = getattr(self.talker_config, "spk_id", None) or {} if speaker.lower() not in spk_id_map: raise ValueError(f"Unsupported speaker: {speaker}") @@ -1569,15 +1569,51 @@ def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]): logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGeneration", len(loaded)) return loaded + def enable_cudagraph(self, device: torch.device | None = None, max_batch_size: int = 1): + from .cuda_graph_talker_wrapper import TalkerMTPCudaGraphWrapper + + if device is None: + device = next(self.model.parameters()).device + if device.type != "cuda": + logger.warning("Cannot enable CUDA Graph: talker is not on a CUDA device (got %s)", device) + return + + self._cudagraph_wrapper = TalkerMTPCudaGraphWrapper( + talker_model=self, + talker_config=self.talker_config, + enabled=True, + max_batch_size=max_batch_size, + ) + try: + self._cudagraph_wrapper.warmup(device) + self._cudagraph_enabled = True + logger.info("CUDA Graph enabled for TTS talker MTP (max_batch_size=%d)", max_batch_size) + except Exception: + self._cudagraph_wrapper = None + logger.warning("CUDA Graph capture failed; falling back to eager execution", exc_info=True) + # -------------------- GPU-side MTP fast-path -------------------- - @torch.inference_mode() def talker_mtp( self, input_ids: torch.Tensor, input_embeds: torch.Tensor, last_talker_hidden: torch.Tensor, text_step: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run talker graph if available, else fall back to eager""" + with record_function_or_nullcontext("talker_mtp"): + if self._cudagraph_enabled: + return self._cudagraph_wrapper._talker_mtp(input_ids, input_embeds, last_talker_hidden, text_step) + return self._talker_mtp(input_ids, input_embeds, last_talker_hidden, text_step) + + @torch.inference_mode() + def _talker_mtp( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1). Returns (inputs_embeds, audio_codes) for the current step.""" diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index af3006ab2bd..8547f38d62b 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -98,16 +98,23 @@ def load_model(self, *args, **kwargs) -> None: cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that - # have a separate .talker sub-module. TTS models' code predictor - # has internal AR loops / torch.multinomial — not graph-safe. + # have a separate .talker sub-module. TTS models use model-specific graph implementation. has_separate_talker = getattr(self.model, "talker", None) is not None - if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: - self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. hidden_size = int( getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size") ) + # Defaults to max_num_seqs if max_cudagraph_capture_size is not set max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) + if cudagraph_mode.has_full_cudagraphs(): + if has_separate_talker: + self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + else: + if hasattr(self.model, "enable_cudagraph"): + try: + self.model.enable_cudagraph(max_batch_size=max_batch_size) + except Exception: + logger.warning("Failed to enable CUDA graph for TTS talker", exc_info=True) self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) self.talker_mtp_inputs_embeds = self._make_buffer( max_batch_size, hidden_size, dtype=self.dtype, numpy=False @@ -1359,9 +1366,10 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: if req_state is None: return existing = self.model_intermediate_buffer.setdefault(req_id, {}) + gpu_keys: set[str] = getattr(self.model, "gpu_resident_buffer_keys", set()) for k, v in upd.items(): if isinstance(v, torch.Tensor): - existing[k] = v.detach().to("cpu").contiguous() + existing[k] = v.detach() if k in gpu_keys else v.detach().to("cpu").contiguous() elif isinstance(v, list): existing[k] = [ (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v