diff --git a/examples/offline_inference/voxtral_tts/end2end.py b/examples/offline_inference/voxtral_tts/end2end.py index cf28f917e99..0a6f88715a9 100644 --- a/examples/offline_inference/voxtral_tts/end2end.py +++ b/examples/offline_inference/voxtral_tts/end2end.py @@ -298,6 +298,12 @@ def parse_args() -> Namespace: default=None, help="Voice to use instead of audio file.", ) + parser.add_argument( + "--cfg-alpha", + type=float, + default=None, + help="CFG alpha for flow-matching guidance (default: use value from stage config, typically 1.2).", + ) return parser.parse_args() @@ -349,8 +355,13 @@ def main(args: Any) -> None: inputs = compose_request(model_name, text_chunk, audio_prompt_file, args) + extra_args = {} + if args.cfg_alpha is not None: + extra_args["cfg_alpha"] = args.cfg_alpha + sampling_params = SamplingParams( max_tokens=max_num_tokens, + extra_args=extra_args if extra_args else None, ) sampling_params_list = [ sampling_params, diff --git a/examples/online_serving/voxtral_tts/gradio_demo.py b/examples/online_serving/voxtral_tts/gradio_demo.py index 35d6b590c97..7905c62618c 100644 --- a/examples/online_serving/voxtral_tts/gradio_demo.py +++ b/examples/online_serving/voxtral_tts/gradio_demo.py @@ -216,6 +216,7 @@ def update_voice_dropdown(language: str) -> gr.Dropdown: def run_inference( voice_name: str, text_prompt: str, + cfg_alpha: float, base_url: str, model: str, ) -> tuple[int, np.ndarray]: @@ -233,6 +234,7 @@ def run_inference( "model": model, "response_format": "wav", "voice": voice_name, + "extra_params": {"cfg_alpha": cfg_alpha}, } response = httpx.post( @@ -377,6 +379,14 @@ def main( placeholder="Enter the text you want to synthesize...", lines=4, ) + cfg_alpha_slider = gr.Slider( + minimum=1.0, + maximum=2.0, + step=0.1, + value=1.2, + label="CFG Alpha", + info="Flow-matching guidance strength (default: 1.2)", + ) with gr.Row(): reset_btn = gr.Button("Clear") submit_btn = gr.Button("Generate audio", interactive=False) @@ -415,9 +425,9 @@ def _toggle_submit(text: str): ) # --- Wiring inference + persistence to the button --- - def _on_submit(voice: str, text: str): + def _on_submit(voice: str, text: str, cfg_alpha: float): assert text.strip() != "" - sr, audio_array = run_inference(voice, text, base_url, model) + sr, audio_array = run_inference(voice, text, cfg_alpha, base_url, model) if outputs_dir is not None: share_id, saved_audio_path = _save_example( outputs_dir, @@ -432,7 +442,7 @@ def _on_submit(voice: str, text: str): submit_btn.click( fn=_on_submit, - inputs=[voice_name, text_prompt], + inputs=[voice_name, text_prompt, cfg_alpha_slider], outputs=[output_audio, share_link_box], ) @@ -446,6 +456,7 @@ def _on_reset(): language, # language_dropdown voice, # voice_name "", # text_prompt + 1.2, # cfg_alpha_slider None, # output_audio gr.update(interactive=False), # submit_btn "", # share_link_box @@ -456,7 +467,15 @@ def _on_reset(): reset_btn.click( fn=make_on_reset(languages, language_voices), inputs=[], - outputs=[language_dropdown, voice_name, text_prompt, output_audio, submit_btn, share_link_box], + outputs=[ + language_dropdown, + voice_name, + text_prompt, + cfg_alpha_slider, + output_audio, + submit_btn, + share_link_box, + ], ) def make_load_from_share(outputs_dir: Path | None, languages: list[str], language_voices: dict[str, list[str]]): diff --git a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py index 847adae06fa..c7b023361a7 100644 --- a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py +++ b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py @@ -137,7 +137,7 @@ def __init__(self): def compute_mm_logits( self, hidden_states: torch.Tensor, - mm_sampling_tensors=None, + cfg_alpha: torch.Tensor, ): """Eager fallback path: replicate what the wrapper does.""" at = self.acoustic_transformer @@ -216,6 +216,10 @@ def _random_hidden(batch_size, device=DEVICE, dtype=torch.bfloat16): return torch.randn(batch_size, HIDDEN_DIM, device=device, dtype=dtype) +def _cfg_alpha(batch_size, value=1.2, device=DEVICE): + return torch.full((batch_size,), value, device=device, dtype=torch.float32) + + def _unpack_audio_codes(result): """Unpack (fake_eos, {"audio": [list of tensors]}) into (fake_eos, audio_codes).""" fake_eos, mm_tokens = result @@ -235,7 +239,7 @@ def test_exact_size_output_format(model, wrapper, batch_size): """Graph path returns correctly shaped and bounded outputs.""" hidden = _random_hidden(batch_size) with torch.no_grad(): - graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden)) + graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size))) assert graph_eos.shape == (batch_size,) assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK) # fake_eos should be 0.0 or 1.0 @@ -248,11 +252,12 @@ def test_exact_size_output_format(model, wrapper, batch_size): def test_exact_size_deterministic(model, wrapper, batch_size): """Same input + same RNG state produces identical CUDA graph output.""" hidden = _random_hidden(batch_size) + cfg_alpha = _cfg_alpha(batch_size) with torch.no_grad(): torch.manual_seed(42) - eos1, codes1 = _unpack_audio_codes(wrapper(hidden)) + eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha)) torch.manual_seed(42) - eos2, codes2 = _unpack_audio_codes(wrapper(hidden)) + eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha)) torch.testing.assert_close(eos1, eos2, atol=0, rtol=0) torch.testing.assert_close(codes1, codes2, atol=0, rtol=0) @@ -267,7 +272,7 @@ def test_padded_output_shape(model, wrapper, batch_size): """Padded decode must return output trimmed to actual batch size.""" hidden = _random_hidden(batch_size) with torch.no_grad(): - graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden)) + graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size))) assert graph_eos.shape == (batch_size,) assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK) @@ -277,7 +282,7 @@ def test_padded_output_bounded(model, wrapper, batch_size): """Padded output audio codes should be non-negative integers.""" hidden = _random_hidden(batch_size) with torch.no_grad(): - graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden)) + graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size))) # fake_eos should be 0.0 or 1.0 assert torch.all((graph_eos == 0.0) | (graph_eos == 1.0)) # Audio codes should be non-negative @@ -293,11 +298,12 @@ def test_padded_output_bounded(model, wrapper, batch_size): def test_fallback_eager_exact_match(model, wrapper, batch_size): """Cudagraph fallback to eager. Two eager runs should produce identical results.""" hidden = _random_hidden(batch_size) + alpha = _cfg_alpha(batch_size) with torch.no_grad(): torch.manual_seed(100) - eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden)) + eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha)) torch.manual_seed(100) - graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden)) + graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha)) torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0) torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0) @@ -310,12 +316,13 @@ def test_fallback_eager_exact_match(model, wrapper, batch_size): def test_disabled_wrapper_matches_eager(model, wrapper): """Cudagraph fallback to eager. Two eager runs should produce identical results.""" hidden = _random_hidden(4) + alpha = _cfg_alpha(4) wrapper.enabled = False with torch.no_grad(): torch.manual_seed(200) - eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden)) + eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha)) torch.manual_seed(200) - graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden)) + graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha)) wrapper.enabled = True torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0) torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0) @@ -329,10 +336,11 @@ def test_disabled_wrapper_matches_eager(model, wrapper): def test_deterministic_across_calls(model, wrapper): """Same input + same RNG state. Two cudagraph runs should produce identical results.""" hidden = _random_hidden(4) + alpha = _cfg_alpha(4) with torch.no_grad(): torch.manual_seed(300) - eos1, codes1 = _unpack_audio_codes(wrapper(hidden)) + eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha)) torch.manual_seed(300) - eos2, codes2 = _unpack_audio_codes(wrapper(hidden)) + eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha)) torch.testing.assert_close(eos1, eos2, atol=0, rtol=0) torch.testing.assert_close(codes1, codes2, atol=0, rtol=0) diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 96a34a8d79f..af89184a6f5 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -114,6 +114,7 @@ class OmniModelConfig(ModelConfig): codec_frame_rate_hz: float | None = None task_type: str | None = None enable_sleep_mode: bool = False + has_sampling_extra_args: bool = False @property def registry(self): diff --git a/vllm_omni/deploy/voxtral_tts.yaml b/vllm_omni/deploy/voxtral_tts.yaml index b0899a3997d..87d999c67e0 100644 --- a/vllm_omni/deploy/voxtral_tts.yaml +++ b/vllm_omni/deploy/voxtral_tts.yaml @@ -36,6 +36,8 @@ stages: max_tokens: 2048 seed: 42 repetition_penalty: 1.1 + extra_args: + cfg_alpha: 1.2 tokenizer_mode: mistral config_format: mistral load_format: mistral diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index c93bd32c2f1..23e0e05850c 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -170,6 +170,7 @@ def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.Ar output_modalities: list[str] | None = None log_stats: bool = False custom_pipeline_args: dict[str, Any] | None = None + has_sampling_extra_args: bool = False def __post_init__(self) -> None: if self.worker_cls is None: @@ -319,6 +320,7 @@ def create_model_config(self) -> OmniModelConfig: subtalker_sampling_params=self.subtalker_sampling_params, omni_kv_config=self.omni_kv_config, task_type=self.task_type, + has_sampling_extra_args=self.has_sampling_extra_args, ) return omni_config diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index cd6cd6a69c6..94b9faa802c 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -411,6 +411,10 @@ def build_engine_args_dict( if stage_type != "diffusion": resolve_worker_cls(engine_args_dict) + # Check whether the stage's default_sampling_params defines extra_args. + default_sp = _to_dict(getattr(stage_config, "default_sampling_params", {})) + engine_args_dict["has_sampling_extra_args"] = bool(default_sp.get("extra_args")) + return engine_args_dict diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index 8468efd8613..59b5777a874 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -1,5 +1,5 @@ import math -from typing import Literal +from typing import Any, Literal import numpy as np from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator @@ -74,6 +74,10 @@ class OpenAICreateSpeechRequest(BaseModel): ge=0, description="Per-request initial chunk size override. If null, computed dynamically based on server load.", ) + extra_params: dict[str, Any] | None = Field( + default=None, + description=("Optional model-specific parameters passed directly to the model's extra_args."), + ) @field_validator("stream_format") @classmethod diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index c275c779590..5f686c66671 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -8,13 +8,14 @@ import struct import time from concurrent.futures import ThreadPoolExecutor +from http import HTTPStatus from pathlib import Path from typing import Any import numpy as np import soundfile as sf import torch -from fastapi import Request, UploadFile +from fastapi import HTTPException, Request, UploadFile from fastapi.responses import Response, StreamingResponse from transformers.utils.hub import cached_file from vllm.entrypoints.launcher import terminate_if_errored @@ -1664,6 +1665,21 @@ async def _prepare_speech_generation( max_tokens, ) + # Apply model-specific extra parameters + if request.extra_params is not None and sampling_params_list: + if not isinstance(request.extra_params, dict): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="extra_params must be a JSON object/dict.", + ) + import copy + + sampling_params_list = copy.deepcopy(sampling_params_list) + if sampling_params_list[0].extra_args is None: + sampling_params_list[0].extra_args = {} + sampling_params_list[0].extra_args.update(request.extra_params) + logger.info("Applied extra_params: %s", request.extra_params) + # Fish defaults come from stage_configs YAML. Only override when the caller # explicitly requests a different generation length. if self._is_fish_speech and request.max_new_tokens is not None and sampling_params_list: diff --git a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py index ff053342dbe..d7407afe561 100644 --- a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py +++ b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py @@ -48,13 +48,13 @@ def __init__( self.n_acoustic_codebook = self.acoustic_transformer.model_args.n_acoustic_codebook self.acoustic_embeddings_levels = self.acoustic_transformer.acoustic_embeddings_levels - self.cfg_alpha = 1.2 self.n_steps = self.acoustic_transformer.acoustic_transformer_args.n_decoding_steps # Graph storage self.graphs: dict[int, CUDAGraph] = {} self.static_inputs: dict[int, torch.Tensor] = {} self.static_noise: dict[int, torch.Tensor] = {} + self.static_cfg_alpha: dict[int, torch.Tensor] = {} self.static_fake_eos: dict[int, torch.Tensor] = {} self.static_audio_codes: dict[int, torch.Tensor] = {} @@ -80,8 +80,10 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d # Phase 1: Eager warmup for ALL capture sizes for size in self.capture_sizes: dummy = torch.zeros(size, hidden_dim, device=device, dtype=dtype) + dummy_cfg_alpha = torch.full((size, 1), 1.2, device=device, dtype=dtype) + dummy_noise = torch.randn(size, self.n_acoustic_codebook, device=device, dtype=dtype) with torch.no_grad(): - self._forward_cudagraph_compatible(dummy) + self._forward_cudagraph_compatible(dummy, cfg_alpha=dummy_cfg_alpha, noise=dummy_noise) torch.cuda.synchronize(device) @@ -105,7 +107,12 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d len(self.capture_sizes), ) - def _forward_cudagraph_compatible(self, hidden_states: torch.Tensor, noise: torch.Tensor | None = None): + def _forward_cudagraph_compatible( + self, + hidden_states: torch.Tensor, + cfg_alpha: torch.Tensor, + noise: torch.Tensor, + ): """ The actual computation captured by the CUDA graph. @@ -117,6 +124,7 @@ def _forward_cudagraph_compatible(self, hidden_states: torch.Tensor, noise: torc - Calls _predict_velocity directly - Uses a pre-allocated noise buffer to avoid baking random state into the CUDA graph + - Uses a pre-allocated cfg_alpha buffer for per-request CFG strength """ at = self.acoustic_transformer B = hidden_states.shape[0] @@ -132,10 +140,7 @@ def _forward_cudagraph_compatible(self, hidden_states: torch.Tensor, noise: torc # --- Flow matching: Euler ODE --- should_decode = semantic_code.squeeze(1) != self.end_audio_token_id - if noise is not None: - x = noise - else: - x = torch.randn(B, self.n_acoustic_codebook, device=hidden_states.device, dtype=hidden_states.dtype) + x = noise # Pre-compute zero hidden states for unconditional CFG branch hidden_states_zero = torch.zeros_like(hidden_states) @@ -154,8 +159,8 @@ def _forward_cudagraph_compatible(self, hidden_states: torch.Tensor, noise: torc v_all = at._predict_velocity(x_t=x_batched, llm_output=llm_batched, t_emb=t_emb_batched) v_t, uncond_v_t = v_all[:B], v_all[B:] - # CFG combination - v_t = self.cfg_alpha * v_t + (1 - self.cfg_alpha) * uncond_v_t + # CFG combination (cfg_alpha is (B, 1), v_t is (B, C)) + v_t = cfg_alpha * v_t + (1 - cfg_alpha) * uncond_v_t x = x + v_t * dt @@ -188,10 +193,11 @@ def _capture_graph_for_size( """Capture a CUDA graph for a specific batch size.""" static_input = torch.zeros(size, hidden_dim, device=device, dtype=dtype) static_noise = torch.randn(size, self.n_acoustic_codebook, device=device, dtype=dtype) + static_cfg_alpha = torch.full((size, 1), 1.2, device=device, dtype=dtype) # Stabilizing eager run with torch.no_grad(): - _ = self._forward_cudagraph_compatible(static_input, noise=static_noise) + _ = self._forward_cudagraph_compatible(static_input, cfg_alpha=static_cfg_alpha, noise=static_noise) torch.cuda.synchronize(device) @@ -199,12 +205,13 @@ def _capture_graph_for_size( with torch.no_grad(): with torch.cuda.graph(graph, pool=current_platform.get_global_graph_pool()): static_fake_eos, static_audio_codes = self._forward_cudagraph_compatible( - static_input, noise=static_noise + static_input, cfg_alpha=static_cfg_alpha, noise=static_noise ) self.graphs[size] = graph self.static_inputs[size] = static_input self.static_noise[size] = static_noise + self.static_cfg_alpha[size] = static_cfg_alpha self.static_fake_eos[size] = static_fake_eos self.static_audio_codes[size] = static_audio_codes @@ -218,6 +225,7 @@ def _get_padded_size(self, actual_size: int) -> int | None: def __call__( self, hidden_states: torch.Tensor, + cfg_alpha: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]] | None]: """ Drop-in replacement for model.compute_mm_logits(). @@ -229,16 +237,20 @@ def __call__( actual_size = hidden_states.shape[0] if not self.enabled or not self._warmed_up: - return self.model.compute_mm_logits(hidden_states) + return self.model.compute_mm_logits(hidden_states, cfg_alpha=cfg_alpha) padded_size = self._get_padded_size(actual_size) if padded_size is None or padded_size not in self.graphs: - return self.model.compute_mm_logits(hidden_states) + return self.model.compute_mm_logits(hidden_states, cfg_alpha=cfg_alpha) # Zero static input, then copy actual data self.static_inputs[padded_size].zero_() self.static_inputs[padded_size][:actual_size] = hidden_states + # Copy per-request cfg_alpha into static buffer (pad with 1.2 default) + self.static_cfg_alpha[padded_size].fill_(1.2) + self.static_cfg_alpha[padded_size][:actual_size, 0] = cfg_alpha + # Fill noise buffer with fresh random values before replay so the # flow-matching ODE starts from different initial noise each time. self.static_noise[padded_size].normal_() diff --git a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py index 127171067d6..c7808915098 100644 --- a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py +++ b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py @@ -283,6 +283,30 @@ def forward( multimodal_outputs={"audio": batch_audio_arrays}, ) + _DEFAULT_CFG_ALPHA = 1.2 + + def _extract_cfg_alpha(self, input_hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """Extract per-request cfg_alpha from sampling_extra_args. + + Returns a 1-D tensor of shape (B,) with per-request cfg_alpha values. + Falls back to default if sampling_extra_args is missing or incomplete. + """ + B = input_hidden_states.shape[0] + sampling_extra_args = kwargs.get("sampling_extra_args") + if sampling_extra_args is None: + return torch.full( + (B,), + self._DEFAULT_CFG_ALPHA, + device=input_hidden_states.device, + dtype=input_hidden_states.dtype, + ) + cfg_alpha_values = [ea.get("cfg_alpha", self._DEFAULT_CFG_ALPHA) for ea in sampling_extra_args] + return torch.tensor( + cfg_alpha_values, + device=input_hidden_states.device, + dtype=input_hidden_states.dtype, + ) + def make_omni_output( self, model_outputs: torch.Tensor | OmniOutput | tuple, logits_index: int | None = None, **kwargs ) -> OmniOutput: @@ -291,10 +315,15 @@ def make_omni_output( hidden_states = model_outputs assert logits_index is not None input_hidden_states = hidden_states[logits_index] + cfg_alpha = self._extract_cfg_alpha(input_hidden_states, **kwargs) if self._cudagraph_acoustic_transformer is not None: - fake_eos, multimodal_outputs = self._cudagraph_acoustic_transformer(input_hidden_states) + fake_eos, multimodal_outputs = self._cudagraph_acoustic_transformer( + input_hidden_states, cfg_alpha=cfg_alpha + ) else: - fake_eos, multimodal_outputs = self.model.compute_mm_logits(input_hidden_states) + fake_eos, multimodal_outputs = self.model.compute_mm_logits( + input_hidden_states, cfg_alpha=cfg_alpha + ) hidden_states[logits_index, 0] = fake_eos return OmniOutput( text_hidden_states=hidden_states, diff --git a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py index cd67e4f0740..8b7dd7d1370 100644 --- a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py +++ b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py @@ -438,8 +438,6 @@ def __init__( # Flow matching constants self._n_steps = args.n_decoding_steps - # TODO(chenyo): hardcoded, need to fix - self._cfg_alpha = 1.2 self._noise_scale = 1.0 self.register_buffer( "_timesteps", @@ -512,6 +510,7 @@ def decode_one_frame( self, semantic_code: torch.Tensor, llm_hidden: torch.Tensor, + cfg_alpha: torch.Tensor, ) -> torch.Tensor: B = semantic_code.shape[0] @@ -525,6 +524,10 @@ def decode_one_frame( timesteps = self._timesteps.to(dtype=llm_hidden.dtype) llm_hidden_zero = torch.zeros_like(llm_hidden) + # Reshape cfg_alpha for broadcasting: (B,) -> (B, 1) + cfg_alpha = cfg_alpha.to(dtype=llm_hidden.dtype, device=llm_hidden.device) + cfg_alpha = cfg_alpha.unsqueeze(1) # (B, 1) for broadcasting with (B, C) + # Euler integration with batched conditional + unconditional velocity sampled = x_0 for i in range(len(timesteps) - 1): @@ -544,7 +547,7 @@ def decode_one_frame( t_emb=t_emb_batched, ) v_t, uncond_v_t = v_all[:B], v_all[B:] - v_t = self._cfg_alpha * v_t + (1 - self._cfg_alpha) * uncond_v_t + v_t = cfg_alpha * v_t + (1 - cfg_alpha) * uncond_v_t sampled = sampled + v_t * dt @@ -585,6 +588,7 @@ def _predict_velocity( def forward( self, llm_hidden: torch.Tensor, + cfg_alpha: torch.Tensor, ) -> torch.Tensor: # llm_hidden: BxD semantic_logit = self.semantic_codebook_output(llm_hidden).float() @@ -594,10 +598,10 @@ def forward( # semantic_logit: Bx1 semantic_code = semantic_logit.argmax(dim=-1, keepdim=True) - # acoustic codes, TODO(@chenyo): config sampling acoustic_codes = self.decode_one_frame( semantic_code.squeeze(1), llm_hidden, + cfg_alpha=cfg_alpha, ) audio_codes = torch.concatenate( @@ -1035,11 +1039,13 @@ def compute_logits( def compute_mm_logits( self, hidden_states: torch.Tensor, + cfg_alpha: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]: audio_codes = None mm_tokens = None audio_codes = self.acoustic_transformer( llm_hidden=hidden_states, + cfg_alpha=cfg_alpha, ) fake_eos = torch.where( audio_codes[:, 0] == AudioSpecialTokens.id(AudioSpecialTokens.end_audio), diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index d3ccbaaf303..7a6f3b4538d 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1038,6 +1038,15 @@ def _build_model_kwargs_extra(self) -> dict: import traceback traceback.print_exc() + + if getattr(self.model_config, "has_sampling_extra_args", False): + extra_args_list: list[dict] = [] + for req_id in self.input_batch.req_ids: + req = self.requests[req_id] + sp = req.sampling_params if req else None + extra_args_list.append(sp.extra_args if sp and sp.extra_args else {}) + model_kwargs_extra["sampling_extra_args"] = extra_args_list + return model_kwargs_extra def _process_additional_information_updates(