Conversation

@Manalelaidouni
Contributor

What does this PR do?

This PR aims at integrating Vocos model to transformers.

Vocos is a neural vocoder designed for high-quality audio synthesis in TTS pipelines and related tasks; it outperforms HiFi-GAN while being significantly faster. It has two main variants:

  • VocosModel can be used as a standalone vocoder in an audio generation pipeline; the goal is to use it as a drop-in vocoder in the YuE model. It can also be used together with VocosFeatureExtractor to synthesize audio from mel-spectrogram features (see the usage sketch at the end of this description).
  • VocosWithEncodecModel: integrates the EnCodec neural audio codec into Vocos for end-to-end audio compression and reconstruction.

This is a continuation of integrating model components for the new YuE model (mentioned in #36784).
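
A rough usage sketch of the mel workflow (the checkpoint id, output field name, and exact keyword arguments below are placeholders at this stage of the PR, not the final API):

import numpy as np
import torch
from transformers import VocosFeatureExtractor, VocosModel

# placeholder checkpoint id; the converted checkpoints are discussed later in this thread
checkpoint = "hf-audio/vocos-mel-24khz"
feature_extractor = VocosFeatureExtractor.from_pretrained(checkpoint)
model = VocosModel.from_pretrained(checkpoint)

# dummy 1-second waveform at 24 kHz, turned into mel-spectrogram features by the feature extractor
audio = np.random.randn(24000).astype(np.float32)
inputs = feature_extractor(audio=audio, sampling_rate=24000, return_tensors="pt")

with torch.no_grad():
    waveform = model(**inputs).audio  # reconstructed audio waveform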

Who can review?

Anyone in the community is free to review the PR once the tests have passed.
@ArthurZucker @eustlb @ylacombe

@Manalelaidouni Manalelaidouni marked this pull request as draft July 14, 2025 22:50
Collaborator

@ArthurZucker ArthurZucker left a comment

Nice! My main comment is to remove the hidden states post processing!

@ArthurZucker ArthurZucker requested a review from eustlb July 16, 2025 13:33
@Manalelaidouni Manalelaidouni marked this pull request as ready for review July 22, 2025 13:07
@Manalelaidouni Manalelaidouni marked this pull request as draft July 22, 2025 13:26
@Manalelaidouni Manalelaidouni marked this pull request as ready for review July 22, 2025 15:29
@Manalelaidouni
Contributor Author

Manalelaidouni commented Jul 22, 2025

Thanks for reviewing! The failing tests seem unrelated to my changes, but I realized that the latest datasets 4.0.0 loads different audio samples than earlier versions, which was causing the integration tests to fail in CI.

Collaborator

@ArthurZucker ArthurZucker left a comment

Sorry for my late review!

@ArthurZucker
Collaborator

If you can merge main and address the small comment, we can merge!

@ebezzam
Contributor

ebezzam commented Oct 10, 2025

run-slow: vocos

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/vocos']
quantizations: [] ...

@Manalelaidouni
Contributor Author

@Manalelaidouni Thanks! Tests are passing again on my machine and I will try GitHub Actions soon.

Also, I flipped back to the functional ISTFT you used to have 🙈 because @eustlb and I had the same idea of putting it in audio utils.

Nice, great idea actually. Now it looks like the model is good to go, right?

@ebezzam ebezzam requested a review from eustlb October 10, 2025 14:43
Contributor

@eustlb eustlb left a comment

Here we see that different code paths in the processor/model are used in different situations, and they cannot be mixed:

  • if we use the mel-spectrogram inputs, then we should not have an adaptive layer norm
  • if we use the non-mel-spectrogram inputs, then we should have an adaptive layer norm.

We should therefore have two models, Vocos and VocosEncodec.
VocosEncodec should be defined using modular, simply replacing the norm with a VocosAdaptiveLayerNorm.

Moreover, it should have an embedding layer that does the audio codes → inputs_embeds preparation currently done in the processing, under torch.no_grad(). I do get that we would duplicate the weights needed for the codebook embedding in both the processor and the model, yet it's only 4 MB, and I'd rather fix the model inputs as token ids so that it can be used without the processor whenever you already have EnCodec codebook token ids. Moreover, by passing codes directly, we do not need to pass the bandwidth id to VocosEncodec, which makes no sense to pass in the forward since it can be inferred directly from the input_ids shape.

VocosEncodec would simply have in its modular something like:

class VocosEncodecModel(VocosModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.num_codebooks * config.codebook_size, config.hidden_size)
        # offsets shift each codebook's ids into its own slice of the shared embedding table
        self.register_buffer("offsets", torch.arange(config.num_codebooks) * config.codebook_size, persistent=False)
        self.norms = nn.ModuleList(...)
        del self.norm

    def forward(
        self,
        input_ids=None,  # shape (batch_size, seq_len, num_codebooks)
        inputs_embeds=None,
    ):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            num_codebooks = input_ids.shape[2]
            if num_codebooks not in self.config.supported_num_codebooks:
                raise ValueError(f"{num_codebooks}...{self.config.supported_num_codebooks}")
            # pick the layer norm matching the number of codebooks (replaces the bandwidth id)
            self.norm = self.norms[self.config.supported_num_codebooks.index(num_codebooks)]
            inputs_embeds = self.embed_tokens(input_ids + self.offsets[:num_codebooks])

        return super().forward(input_features=inputs_embeds)

Comment on lines +44 to +46

The original code can be found [here](https://github.com/gemelo-ai/vocos) and original checkpoints [here](https://huggingface.co/charactr/vocos-mel-24khz) and [here](https://huggingface.co/charactr/vocos-encodec-24khz).

Contributor

We'd want (ideally before merging, if the authors are responsive) to merge the converted checkpoints directly into their repos, which is way better than having them under hf-audio.

# load the Bark model and processor
bark_id = "suno/bark-small"
bark_processor = BarkProcessor.from_pretrained(bark_id)
bark = BarkModel.from_pretrained(bark_id, device_map="auto")
Contributor

device_map auto fails here

Contributor

Hmm works for me. What error do you get?

Comment on lines 617 to 693
def istft(input, n_fft: int, padding=None, **kwargs) -> "torch.Tensor":
    """
    Performs the Inverse Short-Time Fourier Transform (ISTFT) on STFT coefficients to reconstruct audio in the time domain.
    Adds support for `same` padding as in Vocos:
    https://github.com/gemelo-ai/vocos/blob/c859e3b7b534f3776a357983029d34170ddd6fc3/vocos/spectral_ops.py#L7
    Otherwise falls back to PyTorch's built-in ISTFT implementation `torch.istft`.

    Args:
        input (`torch.Tensor`): Complex-valued STFT coefficients of shape (batch_size, freq_bins, time_frames).
        n_fft (`int`): Size of the FFT.
        padding (`str`, *optional*): Padding mode. Either "center" or "same".
        **kwargs: Additional arguments passed to torch.istft or used for "same" padding:
            - win_length (`int`, *optional*): Window length. Defaults to n_fft.
            - hop_length (`int`, *optional*): Hop length. Defaults to n_fft // 4.
            - window (`torch.Tensor`, *optional*): Window function. Defaults to a Hann window.
            - center (`bool`, *optional*): Used only for the "center" padding mode.

    Returns:
        `torch.Tensor`: Reconstructed audio waveform.

    It computes the ISTFT differently depending on the padding:
        if `center`: uses PyTorch's built-in ISTFT implementation, which uses `center=True` by default.
        if `same`: uses a custom overlap-add ISTFT, since the PyTorch version fails the
            Nonzero Overlap-Add (NOLA) condition when center is False. See issue: https://github.com/pytorch/pytorch/issues/62323
    """
    requires_backends(istft, ["torch"])

    if padding == "center" or padding is None:
        # user may provide center=False in kwargs; pop it so it is not passed to torch.istft twice
        center = kwargs.pop("center", True)
        audio = torch.istft(
            input,
            n_fft=n_fft,
            center=center,
            **kwargs,
        )

    elif padding == "same":
        win_length = kwargs.get("win_length", n_fft)
        hop_length = kwargs.get("hop_length", n_fft // 4)
        window = kwargs.get("window", torch.hann_window(win_length))

        _, _, num_time_frames = input.shape
        pad = (win_length - hop_length) // 2
        # the inverse FFT of each frame
        inverse_fft = torch.fft.irfft(input, n=n_fft, dim=1, norm="backward")
        inverse_fft = inverse_fft * window[None, :, None]

        # combine the overlapping frames with windowing and normalize by the sum of squared window values across
        # overlapping frames to make sure the reconstruction of the audio is accurate
        output_length = (num_time_frames - 1) * hop_length + win_length
        audio = F.fold(
            inverse_fft,
            output_size=(1, output_length),
            kernel_size=(1, win_length),
            stride=(1, hop_length),
        )[:, 0, 0, pad:-pad]
        window_sqrt = window.square().expand(1, num_time_frames, -1).transpose(1, 2)
        norm = F.fold(
            window_sqrt,
            output_size=(1, output_length),
            kernel_size=(1, win_length),
            stride=(1, hop_length),
        ).squeeze()[pad:-pad]

        if torch.any(norm <= 1e-11):
            raise ValueError(
                "Normalization tensor `norm` contains values ≤ 1e-11, which would cause division by zero. "
                "Check the n_fft, hop_length and padding parameters."
            )
        audio = audio / norm

    else:
        raise ValueError(f"Unsupported padding mode: {padding}. Supported modes are 'center' and 'same'.")

    return audio
Contributor

Nice initiative, but let's revert and simply add a TODO in the original code. Such an important function would require proper testing, work that has to be done when refactoring audio_utils.

Contributor

I've put it in the modeling file!

FYI, while the mel variant uses "center", the EnCodec variant and Xcodec2 use "same".
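
For reference, a minimal sketch of calling the helper above with the two padding modes (the shapes and STFT parameters are made up for illustration, and this assumes the function is importable as istft wherever it ends up living):

import torch

# fake complex STFT coefficients of shape (batch_size, freq_bins = n_fft // 2 + 1, time_frames)
n_fft, hop_length = 1024, 256
stft_coeffs = torch.randn(1, n_fft // 2 + 1, 20, dtype=torch.complex64)

# mel variant: "center" padding delegates to torch.istft with center=True
audio_center = istft(stft_coeffs, n_fft=n_fft, padding="center", hop_length=hop_length)

# EnCodec variant (and Xcodec2): "same" padding goes through the custom overlap-add path
audio_same = istft(stft_coeffs, n_fft=n_fft, padding="same", hop_length=hop_length)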

Comment on lines 233 to 234
audio_spectrogram: Optional[torch.FloatTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
Contributor

And what if the user provides both? We silently use input_features. There is no reason to provide both, so such a use case has no reason to be integrated into the model's forward signature. audio_spectrogram and input_features are both input features (setting the renaming aside, let's keep it to input_features for now).

Comment on lines 49 to 50
use_adaptive_norm (`bool`, *optional*, defaults to `False`):
Whether to use adaptive layer normalization.
Contributor

So we can have bandwidths that are set while the adaptive norm is not used? That seems overcomplicated, especially when looking at config.json: the user will see bandwidths that are not used. Let's aim for explicitness: either bandwidths are specified, in which case we use them, or it is None and therefore we do not use them.
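
A minimal sketch of that convention (the class name and keyword are assumptions here, and the 1.5/3.0/6.0/12.0 kbps values only mirror the bandwidths of the original vocos-encodec-24khz checkpoint):

from transformers import VocosConfig

# mel variant: no bandwidths, so no adaptive norm is built
mel_config = VocosConfig(bandwidths=None)

# EnCodec variant: bandwidths are given, so the adaptive layer norm is used
encodec_config = VocosConfig(bandwidths=[1.5, 3.0, 6.0, 12.0])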

Contributor

Agreed, removing this.

audio_spectrogram: Optional[torch.FloatTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
bandwidth: Optional[float] = None,
**kwargs: Unpack[TransformersKwargs],
Contributor

Why do we handle kwargs here? And why should it be TransformersKwargs?

Contributor

Because of the padding_mask that could be passed in if we do something like:

inputs = processor(audio=audios, return_tensors="pt")
output = model(**inputs)  # `inputs.padding_mask` would be passed in

Is it ok to have an unused padding_mask input to forward for the convenience of model(**inputs)?

The padding_mask is useful for trimming individual audios (at the output) in the case of batch processing.

Contributor

@eustlb eustlb Oct 23, 2025

Yep, of course I know what a padding_mask is :) In this case we would use a padding_mask kwarg directly if necessary. As you can see, padding_mask is not a TransformersKwargs. To the best of my knowledge there is nothing forcing us to return a padding_mask from the processor.

If padding_mask is not supported, then the processor should not return padding_mask.

Contributor

@ebezzam ebezzam Oct 23, 2025

Whoops, sorry, I know you do 😄 I meant more to explain why it's here but not used.

Do you think it could be useful for the batch usage below, outside of modeling code (to remove known padding)?

inputs = processor(audio=audio, bandwidth=bandwidth, sampling_rate=sampling_rate)
outputs = model(**inputs)
audio_vocos = outputs.audio

# use padding mask to extract audio with same length as original `audio`
for i in range(audio_vocos.shape[0]):
    # remove padding
    padding_mask = inputs.padding_mask[i].bool()
    valid_audio = audio_vocos[i][padding_mask]

Because "same" padding (used by the EnCodec approach) actually pads before the audio, which is also accounted for in the processor for both the audio and the padding mask, just truncating to the original length is not enough.

Contributor

IMO it should be taken as an input, returned in the output, and documented, i.e. using outputs.padding_mask directly, something similar to that.
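
A minimal sketch of that suggestion (the forward signature and the output field are assumptions, not the final API):

inputs = processor(audio=audios, sampling_rate=sampling_rate, return_tensors="pt")
outputs = model(input_features=inputs.input_features, padding_mask=inputs.padding_mask)

# the mask is echoed back on the output, so trimming padded batches needs no extra bookkeeping
for waveform, mask in zip(outputs.audio, outputs.padding_mask):
    valid_audio = waveform[mask.bool()]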

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, vocos

@Manalelaidouni
Copy link
Contributor Author

@eustlb @ebezzam I'll work on splitting this back into two models, Vocos and VocosEncodec.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, vocos, vocos_encodec

@ebezzam
Contributor

ebezzam commented Oct 21, 2025

@Manalelaidouni thanks! I've started the splitting. Before you start/continue, let me first confirm something with @eustlb about the EnCodec variant.

@Manalelaidouni
Contributor Author

Hey @ebezzam, are you done with the changes so I can handle the rest?
