diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 2520a1d87e9..4a90b83afe5 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -346,10 +346,13 @@ def __init__( # Pre-allocated buffers (lazily initialized on first forward). self._proj_buf: torch.Tensor | None = None - self._pos_ids: torch.Tensor | None = None - # torch.compile: fuse small kernels in the 5-layer transformer. - self._compiled_model_fwd: object | None = None + # torch.compile + warmup state (lazily initialized in _setup_compile). + self._compiled_model_fwd = None + self._bucket_sizes: list[int] = [] + self._bucket_pos_ids: dict[int, torch.Tensor] = {} + self._lm_heads_list: list[nn.Module] | None = None + self._codec_embeds_list: list[nn.Module] | None = None def get_input_embeddings(self) -> nn.ModuleList: return self.model.get_input_embeddings() @@ -374,54 +377,74 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue default_weight_loader(params[name], w) loaded.add(name) + return loaded # ------------------------------------------------------------------ # Pre-allocated buffer management # ------------------------------------------------------------------ - def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: + def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_groups + 1 - if ( - self._proj_buf is not None - and self._proj_buf.shape[0] >= bsz - and self._proj_buf.device == device - and self._proj_buf.dtype == dtype - ): + if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype: return + max_bsz = self._vllm_config.scheduler_config.max_num_seqs self._proj_buf = torch.zeros( - bsz, + max_bsz, max_seq, self._cp_hidden, dtype=dtype, device=device, ) - self._pos_ids = torch.arange( - max_seq, - dtype=torch.long, - device=device, - ) def _setup_compile(self) -> None: """Lazily set up torch.compiled model forward for kernel fusion. - Uses ``mode="default"`` so Inductor performs operator fusion without - capturing its own CUDA graphs. This avoids conflicts with vLLM's - ``CUDAGraphWrapper`` which manages CUDA graphs for the main Talker - model on the default stream. + Uses ``mode="reduce-overhead"`` with ``dynamic=False`` so Inductor + captures internal CUDA graphs for fixed shapes, eliminating kernel + launch overhead entirely. """ if self._compiled_model_fwd is not None: return + self._lm_heads_list = list(self.lm_head) + self._codec_embeds_list = list(self.model.codec_embedding) if not current_omni_platform.supports_torch_inductor(): logger.warning_once("code_predictor: torch.compile disabled") self._compiled_model_fwd = self.model.forward return self._compiled_model_fwd = torch.compile( self.model.forward, - mode="default", - dynamic=True, + mode="reduce-overhead", + dynamic=False, ) - logger.info("code_predictor: torch.compile enabled (mode=default)") + logger.info("code_predictor: torch.compile enabled (mode=reduce-overhead, dynamic=False)") + self._warmup_compile() + + def _padded_bsz(self, bsz: int) -> int: + for bucket in self._bucket_sizes: + if bsz <= bucket: + return bucket + return bsz + + def _warmup_compile(self) -> None: + """Warmup power-of-2 batch-size buckets to front-load compilation.""" + max_bsz = self._vllm_config.scheduler_config.max_num_seqs + bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz] + if max_bsz not in bucket_sizes: + bucket_sizes.append(max_bsz) + self._bucket_sizes = sorted(bucket_sizes) + + max_seq = self._num_groups + 1 + device = next(self.model.parameters()).device + base_pos = torch.arange(max_seq, device=device, dtype=torch.long) + + proj_buf = self._proj_buf + for bsz in self._bucket_sizes: + pos_ids = base_pos if bsz == 1 else base_pos.repeat(bsz) + self._bucket_pos_ids[bsz] = pos_ids + for _ in range(3): + self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) + logger.info("code_predictor: warmup done for bucket sizes %s", self._bucket_sizes) # ------------------------------------------------------------------ # Optimized forward: re-prefill + torch.compile + projection cache @@ -454,19 +477,16 @@ def forward( all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device) all_codes[:, 0] = layer0_code.reshape(bsz) - self._ensure_buffers(bsz, device, dtype) + self._ensure_buffers(device, dtype) self._setup_compile() proj_buf = self._proj_buf - pos_ids = self._pos_ids + max_seq = self._num_groups + 1 projection = self.small_to_mtp_projection model_fwd = self._compiled_model_fwd - lm_heads = list(self.lm_head) - codec_embeds = list(self.model.codec_embedding) - - proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1) - proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1) + lm_heads = self._lm_heads_list + codec_embeds = self._codec_embeds_list use_sampling = do_sample and temperature > 0 inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 @@ -475,15 +495,21 @@ def forward( "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0." ) - for step in range(1, num_groups): - seq_len = step + 1 + padded_bsz = self._padded_bsz(bsz) + proj_buf[:padded_bsz].zero_() - projected = proj_buf[:bsz, :seq_len, :] - step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz) + proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1) + proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1) + full_pos_ids = self._bucket_pos_ids.get(padded_bsz) + if full_pos_ids is None: + base_pos = torch.arange(max_seq, device=device, dtype=torch.long) + full_pos_ids = base_pos if padded_bsz == 1 else base_pos.repeat(padded_bsz) - hidden_out = model_fwd(projected, step_pos_ids) + for step in range(1, num_groups): + projected = proj_buf[:padded_bsz, :max_seq, :] - logits = lm_heads[step - 1](hidden_out[:, -1, :]) + hidden_out = model_fwd(projected, full_pos_ids) + logits = lm_heads[step - 1](hidden_out[:bsz, step, :]) if use_sampling: scaled = logits * inv_temperature