diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 453792f4c81..e190ab6b70f 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -52,7 +52,7 @@ def __init__(self, vllm_config: Any): self.get_req_chunk: dict[str, int] = defaultdict(int) self.finished_requests: set[str] = set() self.request_payload = {} - self.code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list) + self.code_prompt_token_ids: dict[str, list[torch.Tensor]] = defaultdict(list) self.request_ids_mapping: dict[str, str] = {} self.waiting_for_chunk_waiting_requests: deque[Any] = deque() diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py index a889f66fdad..3a8042eb2e6 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +from torch.nn.utils.parametrize import remove_parametrizations from vllm.config import VllmConfig from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger @@ -58,6 +59,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._hop_length: int = DAC_HOP_LENGTH self._logged_codec_stats = False + def _bake_weight_norm(self, codec: nn.Module) -> None: + baked = 0 + for module in codec.modules(): + parametrizations = getattr(module, "parametrizations", None) + if not parametrizations: + continue + for name in list(parametrizations.keys()): + remove_parametrizations(module, name, leave_parametrized=True) + baked += 1 + if baked > 0: + logger.info("Baked %d DAC parametrized weights for inference", baked) + + def _cache_attention_masks(self, codec: nn.Module) -> None: + for module in codec.modules(): + if not hasattr(module, "make_mask") or not hasattr(module, "make_window_limited_mask"): + continue + + base_make_mask = module.make_mask + base_make_window_mask = module.make_window_limited_mask + mask_cache: dict[int, torch.Tensor] = {} + window_mask_cache: dict[int, torch.Tensor] = {} + + def make_mask_cached(max_length: int, x_lens: torch.Tensor | None = None, *, _orig=base_make_mask): + if x_lens is not None: + return _orig(max_length, x_lens) + key = int(max_length) + cached = mask_cache.get(key) + if cached is None: + cached = _orig(max_length, x_lens) + mask_cache[key] = cached + return cached + + def make_window_mask_cached( + max_length: int, + x_lens: torch.Tensor | None = None, + *, + _orig=base_make_window_mask, + ): + if x_lens is not None: + return _orig(max_length, x_lens) + key = int(max_length) + cached = window_mask_cache.get(key) + if cached is None: + cached = _orig(max_length, x_lens) + window_mask_cache[key] = cached + return cached + + module.make_mask = make_mask_cached + module.make_window_limited_mask = make_window_mask_cached + def _ensure_codec_loaded(self) -> None: if self._codec is not None: return @@ -87,6 +138,8 @@ def _ensure_codec_loaded(self) -> None: if "generator" in state_dict: state_dict = state_dict["generator"] codec.load_state_dict(state_dict, strict=False) + self._bake_weight_norm(codec) + self._cache_attention_masks(codec) device = self.vllm_config.device_config.device codec = codec.to(device=device, dtype=torch.float32) @@ -160,7 +213,7 @@ def forward( ids = input_ids.reshape(-1).to(dtype=torch.long) request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts")) - parsed: list[tuple[int, int]] = [] + parsed_ctx_frames: list[int] = [] valid_codes_qf: list[torch.Tensor] = [] valid_indices: list[int] = [] left_context_size = [0] * len(request_ids_list) @@ -173,7 +226,7 @@ def forward( for i, req_ids in enumerate(request_ids_list): if req_ids.numel() < 1: - parsed.append((0, 0)) + parsed_ctx_frames.append(0) continue ctx_frames = left_context_size[i] flat = req_ids @@ -185,11 +238,11 @@ def forward( n, q, ) - parsed.append((0, 0)) + parsed_ctx_frames.append(0) continue frames = n // q codes_qf = flat.reshape(q, frames) - parsed.append((ctx_frames, frames)) + parsed_ctx_frames.append(ctx_frames) valid_codes_qf.append(codes_qf) valid_indices.append(i) @@ -219,23 +272,33 @@ def forward( except Exception: pass - # Decode each request individually. - wav_tensors: list[torch.Tensor] = [] - for codes_qf in valid_codes_qf: - codes_bqf = codes_qf.unsqueeze(0) # [1, num_codebooks, num_frames] - num_frames = codes_qf.shape[1] - feature_lengths = torch.tensor([num_frames], device=codes_bqf.device) - with torch.cuda.amp.autocast(enabled=False): - wav, audio_lengths = self._codec.decode(codes_bqf, feature_lengths) - # wav shape: [1, 1, wav_len] - wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len] + feature_lengths = torch.tensor( + [codes_qf.shape[1] for codes_qf in valid_codes_qf], + device=valid_codes_qf[0].device, + dtype=torch.long, + ) + max_frames = int(feature_lengths.max().item()) + batch_size = len(valid_codes_qf) + + codes_bqf = torch.zeros( + (batch_size, q, max_frames), + device=valid_codes_qf[0].device, + dtype=torch.long, + ) + for i, codes_qf in enumerate(valid_codes_qf): + frame_count = int(feature_lengths[i].item()) + codes_bqf[i, :, :frame_count] = codes_qf + + with torch.amp.autocast("cuda", enabled=False): + wav_batch, audio_lengths = self._codec.decode(codes_bqf, feature_lengths) audios: list[torch.Tensor] = [empty] * num_req srs = [sr_tensor] * num_req for j, idx in enumerate(valid_indices): - ctx_frames, actual_frames = parsed[idx] - wav = wav_tensors[j] + ctx_frames = parsed_ctx_frames[idx] + audio_len = int(audio_lengths[j].item()) if audio_lengths.numel() > j else int(wav_batch.shape[-1]) + wav = wav_batch[j, 0, :audio_len] # Trim context frames (left overlap for streaming). if ctx_frames > 0: cut = ctx_frames * self._hop_length diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index f73914c2818..8bbb643ebec 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -308,6 +308,8 @@ def __init__( self._embed_buf: torch.Tensor | None = None self._pos_ids: torch.Tensor | None = None self._compiled_model_fwd: object | None = None + self._compile_attempted = False + self._compile_failed = False def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes @@ -322,11 +324,61 @@ def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device) def _setup_compile(self) -> None: - if self._compiled_model_fwd is not None: + if self._compile_attempted: return - # TODO: Enable torch.compile for performance. Eager for now to avoid - # potential graph-break issues during initial bring-up. - self._compiled_model_fwd = self.model.forward + self._compile_attempted = True + try: + self._compiled_model_fwd = torch.compile( + self.model.forward, + # Keep the helper compiler separate from vLLM's outer + # cudagraph-managed Stage-0 execution. + mode="default", + dynamic=True, + fullgraph=False, + ) + except Exception as exc: + self._compile_failed = True + logger.warning("Failed to enable torch.compile for Fish Speech Fast AR: %s", exc) + self._compiled_model_fwd = self.model.forward + else: + logger.info("Enabled torch.compile for Fish Speech Fast AR forward (mode=default)") + + @torch.inference_mode() + def warmup_compile( + self, + device: torch.device, + dtype: torch.dtype, + batch_sizes: tuple[int, ...] = (1,), + ) -> None: + self._setup_compile() + if self._compiled_model_fwd is self.model.forward or self._compile_failed: + return + for batch_size in batch_sizes: + hidden = torch.zeros((batch_size, self.slow_ar_config.hidden_size), device=device, dtype=dtype) + semantic = torch.full( + (batch_size,), + self.slow_ar_config.semantic_begin_id, + device=device, + dtype=torch.long, + ) + self(hidden, semantic, do_sample=False) + torch.cuda.synchronize(device) + + @torch.inference_mode() + def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor: + # Default-on compile only pays off for single-request decode. For + # batched decode, eager preserves loaded throughput and avoids the + # regression seen with batch>1 compiled execution. + model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward + try: + return model_fwd(step_input, step_pos_ids) + except Exception as exc: + if model_fwd is self.model.forward or self._compile_failed: + raise + self._compile_failed = True + self._compiled_model_fwd = self.model.forward + logger.warning("Fish Speech Fast AR torch.compile fallback to eager after runtime failure: %s", exc) + return self.model.forward(step_input, step_pos_ids) @torch.inference_mode() def forward( @@ -369,7 +421,6 @@ def forward( embed_buf = self._embed_buf pos_ids = self._pos_ids - model_fwd = self._compiled_model_fwd # Position 0: projected Slow AR hidden state. projected = self.fast_project_in(slow_ar_hidden.reshape(bsz, -1)) @@ -390,9 +441,11 @@ def forward( for step in range(1, num_cb): seq_len = step + 1 step_input = embed_buf[:bsz, :seq_len, :] - step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz) + # Use a dense 2D position tensor for every batch size; stride-0 + # views from expand() were fragile under compiled execution. + step_pos_ids = pos_ids[:seq_len].unsqueeze(0).repeat(bsz, 1) - hidden_out = model_fwd(step_input, step_pos_ids) + hidden_out = self._run_model(step_input, step_pos_ids, bsz) logits = self.fast_output(self.fast_norm(hidden_out[:, -1, :])) # Residual codebooks (step >= 1) only have 1024 entries. diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 9c41d87f98f..6145815aac8 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -190,6 +190,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.has_postprocess = True self.mtp_hidden_size = int(self.text_config.hidden_size) self.talker_mtp_output_key = "audio_codes" + self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"} # Qwen3 transformer backbone. self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) @@ -354,28 +355,31 @@ def preprocess( if span_len > 1: # --- Prefill --- - prompt_embeds_cpu = info_dict.get("slow_ar_prompt_embeds") - is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2 + prompt_embeds_buf = info_dict.get("slow_ar_prompt_embeds") + is_first_prefill = not isinstance(prompt_embeds_buf, torch.Tensor) or prompt_embeds_buf.ndim != 2 dev = input_ids.device if is_first_prefill: prompt_embeds = self._build_prefill_embeds(input_ids, info_dict) - prompt_embeds_cpu = prompt_embeds.detach().to("cpu").contiguous() + prompt_embeds_buf = prompt_embeds.detach().to("cpu").contiguous() + if not prompt_embeds_buf.is_pinned(): + prompt_embeds_buf = prompt_embeds_buf.pin_memory() + total_prompt_len = int(prompt_embeds_buf.shape[0]) + next_offset = min(span_len, total_prompt_len) info_update: dict[str, Any] = { - "slow_ar_prompt_embeds": prompt_embeds_cpu, - "prefill_offset": 0, + "slow_ar_prompt_embeds": prompt_embeds_buf if next_offset < total_prompt_len else None, + "prefill_offset": next_offset, } - take = prompt_embeds_cpu[:span_len] + take = prompt_embeds_buf[:span_len] if int(take.shape[0]) < span_len: pad_n = span_len - int(take.shape[0]) pad_embed = self.embed_input_ids( torch.tensor([self._audio_pad_token_id], device=dev, dtype=torch.long) ).reshape(1, -1) - take = torch.cat([take, pad_embed.detach().cpu().expand(pad_n, -1)], dim=0) - prompt_embeds = take.to(device=dev, dtype=torch.bfloat16) - info_update["prefill_offset"] = span_len + take = torch.cat([take, pad_embed.expand(pad_n, -1)], dim=0) + prompt_embeds = take.to(device=dev, dtype=torch.bfloat16, non_blocking=True) zeros = torch.zeros( (prompt_embeds.shape[0], self._num_codebooks), @@ -391,23 +395,26 @@ def preprocess( else: # Subsequent prefill chunk. offset = int(info_dict.get("prefill_offset", 0) or 0) - s = max(0, min(offset, int(prompt_embeds_cpu.shape[0]))) - e = max(0, min(offset + span_len, int(prompt_embeds_cpu.shape[0]))) - take = prompt_embeds_cpu[s:e] + total_prompt_len = int(prompt_embeds_buf.shape[0]) + s = max(0, min(offset, total_prompt_len)) + e = max(0, min(offset + span_len, total_prompt_len)) + take = prompt_embeds_buf[s:e] if int(take.shape[0]) < span_len: pad_n = span_len - int(take.shape[0]) pad_embed = self.embed_input_ids( torch.tensor([self._audio_pad_token_id], device=dev, dtype=torch.long) ).reshape(1, -1) - take = torch.cat([take, pad_embed.detach().cpu().expand(pad_n, -1)], dim=0) - prompt_embeds = take.to(device=dev, dtype=torch.bfloat16) + take = torch.cat([take, pad_embed.expand(pad_n, -1)], dim=0) + prompt_embeds = take.to(device=dev, dtype=torch.bfloat16, non_blocking=True) + next_offset = offset + span_len zeros = torch.zeros((prompt_embeds.shape[0], self._num_codebooks), device=dev, dtype=torch.long) return ( input_ids.clone().fill_(self._audio_pad_token_id), prompt_embeds, { - "prefill_offset": offset + span_len, + "slow_ar_prompt_embeds": prompt_embeds_buf if next_offset < total_prompt_len else None, + "prefill_offset": next_offset, "audio_codes": zeros, }, ) @@ -415,8 +422,8 @@ def preprocess( # --- Decode: span_len == 1 --- dev = input_ids.device - last_hidden_cpu = info_dict.get("last_slow_ar_hidden") - if not isinstance(last_hidden_cpu, torch.Tensor): + last_hidden = info_dict.get("last_slow_ar_hidden") + if not isinstance(last_hidden, torch.Tensor): # First decode step after prefill -- just embed the token directly. logger.warning( "preprocess decode: last_slow_ar_hidden not found (keys=%s), " @@ -437,7 +444,7 @@ def preprocess( info_update = { "mtp_inputs": ( - last_hidden_cpu.to(device=dev, dtype=torch.bfloat16).reshape(1, -1), + last_hidden.to(device=dev, dtype=torch.bfloat16).reshape(1, -1), torch.zeros(1, self.text_config.hidden_size, device=dev, dtype=torch.bfloat16), ), } @@ -447,7 +454,7 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: if hidden_states.numel() == 0: logger.debug("postprocess: empty hidden_states") return {} - last = hidden_states[-1, :].detach().to("cpu").contiguous() + last = hidden_states[-1, :].detach().contiguous() logger.debug("postprocess: saved last_slow_ar_hidden shape=%s", tuple(last.shape)) return {"last_slow_ar_hidden": last} @@ -542,20 +549,19 @@ def talker_mtp( # This ensures the Slow AR sees codes from FastAR(hidden_{t-1}). inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() - for b in range(bsz): - token_id = int(input_ids[b, 0].item()) - is_semantic = self._semantic_begin_id <= token_id <= self._semantic_end_id - if is_semantic: - codes = audio_codes[b] # [num_codebooks] - codebook_sum = torch.zeros(self.text_config.hidden_size, device=dev, dtype=torch.bfloat16) - for i in range(self._num_codebooks): - code_with_offset = codes[i].clamp(min=0) + i * self._codebook_size - emb = self.codebook_embeddings(code_with_offset.unsqueeze(0)) - codebook_sum += emb.squeeze(0).to(dtype=torch.bfloat16) - - # Normalize by sqrt(num_codebooks + 1) as in the reference model - # (scale_codebook_embeddings=True for fish_qwen3_omni). - inputs_embeds_out[b] = (inputs_embeds_out[b] + codebook_sum) / math.sqrt(self._num_codebooks + 1) + semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) + if semantic_mask.any(): + semantic_codes = audio_codes[semantic_mask].clamp(min=0) + offsets = ( + torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size + ).unsqueeze(0) + codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) + + # Normalize by sqrt(num_codebooks + 1) as in the reference model + # (scale_codebook_embeddings=True for fish_qwen3_omni). + inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt( + self._num_codebooks + 1 + ) return inputs_embeds_out, audio_codes.to(dtype=torch.long) @@ -666,4 +672,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if truncated: logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated) + try: + self.fast_ar.warmup_compile( + device=self.codebook_embeddings.weight.device, + dtype=torch.bfloat16, + batch_sizes=(1,), + ) + except Exception as exc: + logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) + return loaded_params diff --git a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml index c00528d1035..9b37972f18c 100644 --- a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml +++ b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml @@ -4,6 +4,7 @@ stage_args: stage_type: llm runtime: devices: "0" + max_batch_size: 16 engine_args: max_num_seqs: 4 model_stage: fish_speech_slow_ar @@ -15,10 +16,10 @@ stage_args: async_scheduling: false enable_prefix_caching: false engine_output_type: latent - gpu_memory_utilization: 0.4 + gpu_memory_utilization: 0.6 distributed_executor_backend: "mp" - max_num_batched_tokens: 2048 - max_model_len: 32768 + max_num_batched_tokens: 3072 + max_model_len: 16384 custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.fish_speech.slow_ar_to_dac_decoder_async_chunk output_connectors: to_stage_1: connector_of_shared_memory @@ -37,6 +38,7 @@ stage_args: stage_type: llm runtime: devices: "0" + max_batch_size: 16 engine_args: max_num_seqs: 1 model_stage: dac_decoder @@ -48,10 +50,10 @@ stage_args: async_scheduling: false enable_prefix_caching: false engine_output_type: audio - gpu_memory_utilization: 0.2 + gpu_memory_utilization: 0.1 distributed_executor_backend: "mp" max_num_batched_tokens: 8192 - max_model_len: 32768 + max_model_len: 16384 engine_input_source: [0] final_output: true final_output_type: audio @@ -70,7 +72,7 @@ runtime: enabled: true defaults: window_size: -1 - max_inflight: 1 + max_inflight: 16 connectors: connector_of_shared_memory: @@ -85,7 +87,7 @@ runtime: # 25 frames ≈ 1.16s of audio at 21.5 Hz. codec_chunk_frames: 25 codec_left_context_frames: 25 - initial_codec_chunk_frames: 0 + initial_codec_chunk_frames: 4 edges: - from: 0 diff --git a/vllm_omni/model_executor/stage_input_processors/fish_speech.py b/vllm_omni/model_executor/stage_input_processors/fish_speech.py index 5be70bc18b2..d857c9123af 100644 --- a/vllm_omni/model_executor/stage_input_processors/fish_speech.py +++ b/vllm_omni/model_executor/stage_input_processors/fish_speech.py @@ -7,8 +7,6 @@ logger = init_logger(__name__) -_NUM_CODEBOOKS = 10 # 1 semantic + 9 residual - def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: """Extract the last frame of audio codes from the pooling output.""" @@ -76,8 +74,7 @@ def slow_ar_to_dac_decoder_async_chunk( if isinstance(pooling_output, dict): frame = _extract_last_frame(pooling_output) if frame is not None: - codec_codes = frame.cpu().tolist() - transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) + transfer_manager.code_prompt_token_ids[request_id].append(frame.detach().to(device="cpu", dtype=torch.long)) elif not finished: return None @@ -114,7 +111,7 @@ def slow_ar_to_dac_decoder_async_chunk( if finished: return { "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), + "finished": True, } return None @@ -142,10 +139,11 @@ def slow_ar_to_dac_decoder_async_chunk( window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:] # Pack into codebook-major flat codes. - code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() + stacked_frames = torch.stack(window_frames, dim=0) + code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1).tolist() return { "code_predictor_codes": code_predictor_codes, "left_context_size": left_context_size, - "finished": torch.tensor(finished, dtype=torch.bool), + "finished": finished, }