Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
| _MING_TTS_MODEL_STAGES
| _MOSS_TTS_MODEL_STAGES
)
_SAMPLING_MAX_TOKENS_TTS_MODEL_TYPES = {"fish_tts", "qwen3_tts", "voxtral_tts", "cosyvoice3", "voxcpm2"}
_TTS_LANGUAGES: set[str] = {
"Auto",
"Chinese",
Expand Down Expand Up @@ -876,6 +877,11 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non
if self._tts_model_type == "voxcpm":
return self._validate_voxcpm_request(request)
if self._tts_model_type == "voxcpm2":
if request.max_new_tokens is not None:
if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN:
return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}"
if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX:
return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}"
return None # VoxCPM2 accepts any text input
if self._tts_model_type == "ming_flash_omni_tts":
return self._validate_ming_tts_request(request)
Expand Down Expand Up @@ -1850,13 +1856,23 @@ async def _prepare_speech_generation(
sampling_params_list[0].extra_args.update(request.extra_params)
logger.info("Applied extra_params: %s", request.extra_params)

# Fish defaults come from stage_configs YAML. Only override when the caller
# explicitly requests a different generation length.
if self._is_fish_speech and request.max_new_tokens is not None and sampling_params_list:
# Some TTS model defaults come from deploy YAML. Their AR
# generation length is controlled by SamplingParams.max_tokens, so only
# override it when the caller explicitly requests max_new_tokens.
if (
self._tts_model_type in _SAMPLING_MAX_TOKENS_TTS_MODEL_TYPES
and request.max_new_tokens is not None
and sampling_params_list
):
import copy

sampling_params_list = copy.deepcopy(sampling_params_list)
sampling_params_list[0].max_tokens = request.max_new_tokens
if self._tts_model_type == "cosyvoice3":
sampling_params_list[0].min_tokens = min(
getattr(sampling_params_list[0], "min_tokens", 0),
request.max_new_tokens,
)

# Propagate per-request seed to sampling params so both Slow AR
# and Fast AR produce deterministic output for the same seed.
Expand Down
Loading