[Voxtral TTS] Support voxtral tts cfg_alpha sampling params via temperature#2243
[Voxtral TTS] Support voxtral tts cfg_alpha sampling params via temperature#2243y123456y78 wants to merge 20 commits into
Conversation
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
ea989bd to
03f3d9b
Compare
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
Signed-off-by: Chen-Yo Sun <chenyo.sun@mistral.ai>
lishunyang12
left a comment
There was a problem hiding this comment.
left a couple comments
| sampling_metadata = kwargs.get("sampling_metadata") | ||
| if sampling_metadata is None or sampling_metadata.temperature is None: | ||
| raise ValueError( | ||
| "VoxtralTTS requires a non-zero 'temperature' sampling parameter (used as cfg_alpha for flow-matching)." |
There was a problem hiding this comment.
This will silently accept temperature=0, which would zero out the conditional velocity entirely (pure unconditional generation). Should probably validate temperature > 0 here, or at least != 0.
| final_output_type: text | ||
| default_sampling_params: | ||
| temperature: 0.0 | ||
| # NOTE: VoxtralTTS repurposes 'temperature' as the CFG alpha |
There was a problem hiding this comment.
Nit: might be worth adding a user-facing note somewhere (CLI help, docs) that temperature controls CFG strength for voxtral-tts — otherwise people will set temperature=0.7 expecting normal sampling behavior and get confused.
| padded_size = self._get_padded_size(actual_size) | ||
| if padded_size is None or padded_size not in self.graphs: | ||
| return self.model.compute_mm_logits(hidden_states) | ||
| return self.model.compute_mm_logits(hidden_states, cfg_alpha=cfg_alpha) |
There was a problem hiding this comment.
The 1D -> 2D reshape (unsqueeze(1)) happens inside decode_one_frame, but in the graph path static_cfg_alpha is already (size, 1). This means the eager fallback via compute_mm_logits will unsqueeze, but the graph path skips it. Works today but the shape contract is fragile — a comment on the expected shape at this interface would help.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
|
Resolve conflicts. |
Purpose
Testing Plan
Result