Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/offline_inference/voxtral_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def parse_args() -> Namespace:
default=None,
help="Voice to use instead of audio file.",
)
parser.add_argument(
"--cfg-alpha",
type=float,
default=None,
help="CFG alpha for flow-matching guidance (default: use value from stage config, typically 1.2).",
)
return parser.parse_args()


Expand Down Expand Up @@ -349,8 +355,13 @@ def main(args: Any) -> None:

inputs = compose_request(model_name, text_chunk, audio_prompt_file, args)

extra_args = {}
if args.cfg_alpha is not None:
extra_args["cfg_alpha"] = args.cfg_alpha

sampling_params = SamplingParams(
max_tokens=max_num_tokens,
extra_args=extra_args if extra_args else None,
)
sampling_params_list = [
sampling_params,
Expand Down
27 changes: 23 additions & 4 deletions examples/online_serving/voxtral_tts/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def update_voice_dropdown(language: str) -> gr.Dropdown:
def run_inference(
voice_name: str,
text_prompt: str,
cfg_alpha: float,
base_url: str,
model: str,
) -> tuple[int, np.ndarray]:
Expand All @@ -233,6 +234,7 @@ def run_inference(
"model": model,
"response_format": "wav",
"voice": voice_name,
"extra_params": {"cfg_alpha": cfg_alpha},
}

response = httpx.post(
Expand Down Expand Up @@ -377,6 +379,14 @@ def main(
placeholder="Enter the text you want to synthesize...",
lines=4,
)
cfg_alpha_slider = gr.Slider(
minimum=1.0,
maximum=2.0,
step=0.1,
value=1.2,
label="CFG Alpha",
info="Flow-matching guidance strength (default: 1.2)",
)
with gr.Row():
reset_btn = gr.Button("Clear")
submit_btn = gr.Button("Generate audio", interactive=False)
Expand Down Expand Up @@ -415,9 +425,9 @@ def _toggle_submit(text: str):
)

# --- Wiring inference + persistence to the button ---
def _on_submit(voice: str, text: str):
def _on_submit(voice: str, text: str, cfg_alpha: float):
assert text.strip() != ""
sr, audio_array = run_inference(voice, text, base_url, model)
sr, audio_array = run_inference(voice, text, cfg_alpha, base_url, model)
if outputs_dir is not None:
share_id, saved_audio_path = _save_example(
outputs_dir,
Expand All @@ -432,7 +442,7 @@ def _on_submit(voice: str, text: str):

submit_btn.click(
fn=_on_submit,
inputs=[voice_name, text_prompt],
inputs=[voice_name, text_prompt, cfg_alpha_slider],
outputs=[output_audio, share_link_box],
)

Expand All @@ -446,6 +456,7 @@ def _on_reset():
language, # language_dropdown
voice, # voice_name
"", # text_prompt
1.2, # cfg_alpha_slider
None, # output_audio
gr.update(interactive=False), # submit_btn
"", # share_link_box
Expand All @@ -456,7 +467,15 @@ def _on_reset():
reset_btn.click(
fn=make_on_reset(languages, language_voices),
inputs=[],
outputs=[language_dropdown, voice_name, text_prompt, output_audio, submit_btn, share_link_box],
outputs=[
language_dropdown,
voice_name,
text_prompt,
cfg_alpha_slider,
output_audio,
submit_btn,
share_link_box,
],
)

def make_load_from_share(outputs_dir: Path | None, languages: list[str], language_voices: dict[str, list[str]]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self):
def compute_mm_logits(
self,
hidden_states: torch.Tensor,
mm_sampling_tensors=None,
cfg_alpha: torch.Tensor,
):
"""Eager fallback path: replicate what the wrapper does."""
at = self.acoustic_transformer
Expand Down Expand Up @@ -216,6 +216,10 @@ def _random_hidden(batch_size, device=DEVICE, dtype=torch.bfloat16):
return torch.randn(batch_size, HIDDEN_DIM, device=device, dtype=dtype)


def _cfg_alpha(batch_size, value=1.2, device=DEVICE):
return torch.full((batch_size,), value, device=device, dtype=torch.float32)


def _unpack_audio_codes(result):
"""Unpack (fake_eos, {"audio": [list of tensors]}) into (fake_eos, audio_codes)."""
fake_eos, mm_tokens = result
Expand All @@ -235,7 +239,7 @@ def test_exact_size_output_format(model, wrapper, batch_size):
"""Graph path returns correctly shaped and bounded outputs."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
assert graph_eos.shape == (batch_size,)
assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK)
# fake_eos should be 0.0 or 1.0
Expand All @@ -248,11 +252,12 @@ def test_exact_size_output_format(model, wrapper, batch_size):
def test_exact_size_deterministic(model, wrapper, batch_size):
"""Same input + same RNG state produces identical CUDA graph output."""
hidden = _random_hidden(batch_size)
cfg_alpha = _cfg_alpha(batch_size)
with torch.no_grad():
torch.manual_seed(42)
eos1, codes1 = _unpack_audio_codes(wrapper(hidden))
eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha))
torch.manual_seed(42)
eos2, codes2 = _unpack_audio_codes(wrapper(hidden))
eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha))
torch.testing.assert_close(eos1, eos2, atol=0, rtol=0)
torch.testing.assert_close(codes1, codes2, atol=0, rtol=0)

Expand All @@ -267,7 +272,7 @@ def test_padded_output_shape(model, wrapper, batch_size):
"""Padded decode must return output trimmed to actual batch size."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
assert graph_eos.shape == (batch_size,)
assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK)

Expand All @@ -277,7 +282,7 @@ def test_padded_output_bounded(model, wrapper, batch_size):
"""Padded output audio codes should be non-negative integers."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
# fake_eos should be 0.0 or 1.0
assert torch.all((graph_eos == 0.0) | (graph_eos == 1.0))
# Audio codes should be non-negative
Expand All @@ -293,11 +298,12 @@ def test_padded_output_bounded(model, wrapper, batch_size):
def test_fallback_eager_exact_match(model, wrapper, batch_size):
"""Cudagraph fallback to eager. Two eager runs should produce identical results."""
hidden = _random_hidden(batch_size)
alpha = _cfg_alpha(batch_size)
with torch.no_grad():
torch.manual_seed(100)
eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden))
eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha))
torch.manual_seed(100)
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0)
torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0)

Expand All @@ -310,12 +316,13 @@ def test_fallback_eager_exact_match(model, wrapper, batch_size):
def test_disabled_wrapper_matches_eager(model, wrapper):
"""Cudagraph fallback to eager. Two eager runs should produce identical results."""
hidden = _random_hidden(4)
alpha = _cfg_alpha(4)
wrapper.enabled = False
with torch.no_grad():
torch.manual_seed(200)
eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden))
eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha))
torch.manual_seed(200)
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
wrapper.enabled = True
torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0)
torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0)
Expand All @@ -329,10 +336,11 @@ def test_disabled_wrapper_matches_eager(model, wrapper):
def test_deterministic_across_calls(model, wrapper):
"""Same input + same RNG state. Two cudagraph runs should produce identical results."""
hidden = _random_hidden(4)
alpha = _cfg_alpha(4)
with torch.no_grad():
torch.manual_seed(300)
eos1, codes1 = _unpack_audio_codes(wrapper(hidden))
eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
torch.manual_seed(300)
eos2, codes2 = _unpack_audio_codes(wrapper(hidden))
eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
torch.testing.assert_close(eos1, eos2, atol=0, rtol=0)
torch.testing.assert_close(codes1, codes2, atol=0, rtol=0)
1 change: 1 addition & 0 deletions vllm_omni/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class OmniModelConfig(ModelConfig):
codec_frame_rate_hz: float | None = None
task_type: str | None = None
enable_sleep_mode: bool = False
has_sampling_extra_args: bool = False

@property
def registry(self):
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/deploy/voxtral_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ stages:
max_tokens: 2048
seed: 42
repetition_penalty: 1.1
extra_args:
cfg_alpha: 1.2
tokenizer_mode: mistral
config_format: mistral
load_format: mistral
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.Ar
output_modalities: list[str] | None = None
log_stats: bool = False
custom_pipeline_args: dict[str, Any] | None = None
has_sampling_extra_args: bool = False

def __post_init__(self) -> None:
if self.worker_cls is None:
Expand Down Expand Up @@ -319,6 +320,7 @@ def create_model_config(self) -> OmniModelConfig:
subtalker_sampling_params=self.subtalker_sampling_params,
omni_kv_config=self.omni_kv_config,
task_type=self.task_type,
has_sampling_extra_args=self.has_sampling_extra_args,
)
return omni_config

Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/engine/stage_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,10 @@ def build_engine_args_dict(
if stage_type != "diffusion":
resolve_worker_cls(engine_args_dict)

# Check whether the stage's default_sampling_params defines extra_args.
default_sp = _to_dict(getattr(stage_config, "default_sampling_params", {}))
engine_args_dict["has_sampling_extra_args"] = bool(default_sp.get("extra_args"))

return engine_args_dict


Expand Down
6 changes: 5 additions & 1 deletion vllm_omni/entrypoints/openai/protocol/audio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Literal
from typing import Any, Literal

import numpy as np
from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator
Expand Down Expand Up @@ -74,6 +74,10 @@ class OpenAICreateSpeechRequest(BaseModel):
ge=0,
description="Per-request initial chunk size override. If null, computed dynamically based on server load.",
)
extra_params: dict[str, Any] | None = Field(
default=None,
description=("Optional model-specific parameters passed directly to the model's extra_args."),
)

@field_validator("stream_format")
@classmethod
Expand Down
18 changes: 17 additions & 1 deletion vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import struct
import time
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from pathlib import Path
from typing import Any

import numpy as np
import soundfile as sf
import torch
from fastapi import Request, UploadFile
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from transformers.utils.hub import cached_file
from vllm.entrypoints.launcher import terminate_if_errored
Expand Down Expand Up @@ -1664,6 +1665,21 @@ async def _prepare_speech_generation(
max_tokens,
)

# Apply model-specific extra parameters
if request.extra_params is not None and sampling_params_list:
if not isinstance(request.extra_params, dict):
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="extra_params must be a JSON object/dict.",
)
import copy

sampling_params_list = copy.deepcopy(sampling_params_list)
if sampling_params_list[0].extra_args is None:
sampling_params_list[0].extra_args = {}
sampling_params_list[0].extra_args.update(request.extra_params)
logger.info("Applied extra_params: %s", request.extra_params)

# Fish defaults come from stage_configs YAML. Only override when the caller
# explicitly requests a different generation length.
if self._is_fish_speech and request.max_new_tokens is not None and sampling_params_list:
Expand Down
Loading
Loading