Add xcodec2 model #37868
@@ -0,0 +1,89 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on 2025-02-06 and added to Hugging Face Transformers on 2025-04-29.*

# X-Codec2

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview

The X-Codec2 model was proposed in [Llasa: Scaling Train-Time and Inference-Time Compute for Llama-based Speech Synthesis](https://huggingface.co/papers/2502.04128).

X-Codec2 is a neural audio codec designed to improve speech synthesis and general audio generation in large language model (LLM) pipelines. It extends the original X-Codec by refining how semantic and acoustic information is integrated and tokenized, enabling efficient, high-fidelity audio representation.

Its architecture is based on [X-Codec](./xcodec) with several major differences:

- **Unified Semantic-Acoustic Tokenization**: X-Codec2 fuses the outputs of a semantic encoder (e.g., Wav2Vec2-BERT) and an acoustic encoder into a single embedding, capturing both high-level meaning (e.g., text content, emotion) and low-level audio detail (e.g., timbre).
- **Single-Stage Quantization**: Unlike the multi-layer residual VQ used in most approaches (e.g., [X-Codec](./xcodec), [DAC](./dac), [EnCodec](./encodec)), X-Codec2 uses a single-layer Finite Scalar Quantization (FSQ) for stability and compatibility with causal, autoregressive LLMs; see the sketch after this list.
- **Semantic Supervision During Training**: A semantic reconstruction loss ensures that the discrete tokens preserve meaningful linguistic and emotional information, which is crucial for TTS tasks.
- **Transformer-Friendly Design**: The 1D token structure of X-Codec2 naturally aligns with autoregressive modeling in LLMs like LLaMA, improving training efficiency and downstream compatibility.
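To make the single-stage quantization concrete, here is a minimal, hedged sketch of finite scalar quantization in general. The level counts, latent width, and function name below are illustrative assumptions, not the values X-Codec2 itself uses.

```python
import torch


def fsq_quantize(z: torch.Tensor, levels: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
    """Round each latent dimension to a fixed number of levels and pack the result into one token id."""
    levels_t = torch.tensor(levels, dtype=z.dtype, device=z.device)
    bounded = torch.tanh(z)                                 # squash each dimension into (-1, 1)
    idx = torch.round((bounded + 1) / 2 * (levels_t - 1))   # per-dimension integer index in {0, ..., L-1}
    quantized = idx / (levels_t - 1) * 2 - 1                # map the index back to [-1, 1]
    # Mixed-radix packing: one integer code per frame instead of a stack of residual codebooks.
    radix = torch.cumprod(torch.tensor([1] + levels[:-1], device=z.device), dim=0)
    codes = (idx * radix).sum(dim=-1).long()
    return quantized, codes


z = torch.randn(2, 50, 8)                   # [batch, frames, fsq_dim], illustrative shapes
quantized, codes = fsq_quantize(z, levels=[5] * 8)
print(quantized.shape, codes.shape)         # torch.Size([2, 50, 8]) torch.Size([2, 50])
```

Because every frame collapses to a single integer code, the token stream stays a flat 1D sequence; a real implementation would also add a straight-through estimator so gradients can flow through the rounding.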
## Usage example

Here is a quick example of how to encode and decode audio using this model:

```python
>>> import torch
>>> from datasets import Audio, load_dataset
>>> from transformers import AutoFeatureExtractor, Xcodec2Model

>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

>>> # load model and feature extractor
>>> model_id = "hf-audio/xcodec2"
>>> model = Xcodec2Model.from_pretrained(model_id).to(torch_device).eval()
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

>>> # load data
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
>>> audio = dataset[0]["audio"]["array"]

>>> # prepare data
>>> inputs = feature_extractor(audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(torch_device)

>>> # encode and decode
>>> audio_codes = model.encode(**inputs).audio_codes
>>> audio_values = model.decode(audio_codes).audio_values

>>> # or the equivalent with a single forward pass
>>> model_output = model(**inputs)
>>> audio_codes = model_output.audio_codes
>>> audio_values = model_output.audio_values
```
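If you want to listen to the reconstruction, a small follow-up sketch (assuming `soundfile` is installed and that the decoder output is at the feature extractor's sampling rate):

```python
>>> import soundfile as sf

>>> # audio_values is the decoded waveform from the example above
>>> waveform = audio_values[0].detach().cpu().numpy().squeeze()
>>> sf.write("reconstruction.wav", waveform, feature_extractor.sampling_rate)
```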
This model was contributed by [Steven Zheng](https://huggingface.co/Steveeeeeeen) and [Eric Bezzam](https://huggingface.co/bezzam).
The original code can be found [here](https://github.com/zhenye234/X-Codec-2.0).

## Xcodec2Config

[[autodoc]] Xcodec2Config

## Xcodec2FeatureExtractor

[[autodoc]] Xcodec2FeatureExtractor
- __call__

## Xcodec2Model

[[autodoc]] Xcodec2Model
- decode
- encode
- forward

@@ -33,12 +33,17 @@
```python
    is_librosa_available,
    is_numpy_array,
    is_soundfile_available,
    is_torch_available,
    is_torch_tensor,
    is_torchcodec_available,
    requires_backends,
)


if is_torch_available():
    import torch
    import torch.nn.functional as F

if TYPE_CHECKING:
    import torch
```
@@ -1032,6 +1037,131 @@ def spectrogram_batch(

Contributor (on `spectrogram_torch`): Added Torch-equivalent to `spectrogram_batch()` so that torch/GPU is supported by the feature extractor. Could also update

```python
    return spectrogram_list


def spectrogram_torch(
    waveform_list: list["torch.Tensor"],
    window: "torch.Tensor",
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: float = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    dither: float = 0.0,
    preemphasis: Optional[float] = None,
    mel_filters: Optional["torch.Tensor"] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = False,
    device: str = "cpu",
    dtype: str = "float32",
):
    """
    PyTorch version of spectrogram_batch().

    For spectrogram computation, tensors are promoted to `torch.float64` for better precision,
    and returned according to `dtype`.
    """
    window_length = len(window)
    if fft_length is None:
        fft_length = frame_length
    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
    if window_length != frame_length:
        raise ValueError(f"window_length ({window_length}) must equal frame_length ({frame_length})")
    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")
    if dtype not in ["float16", "float32", "float64"]:
        raise ValueError(f"dtype must be one of 'float16', 'float32', 'float64', got {dtype}")
    dtype = getattr(torch, dtype)

    # Convert list of waveforms → padded tensor [B, T]
    max_len = max(w.shape[-1] for w in waveform_list)
    padded_waveforms = torch.stack([F.pad(w, (0, max_len - w.shape[-1]), value=0.0) for w in waveform_list]).to(
        device=device, dtype=torch.float64
    )

    # Optional centering (reflect pad)
    if center:
        pad_amt = frame_length // 2
        padded_waveforms = F.pad(padded_waveforms, (pad_amt, pad_amt), mode=pad_mode)

    B, T = padded_waveforms.shape
    num_frames = 1 + (T - frame_length) // hop_length
    fft_func = torch.fft.rfft if onesided else torch.fft.fft
    num_bins = (fft_length // 2 + 1) if onesided else fft_length

    # Promote to float64 for better precision
    window = window.to(device=device, dtype=torch.float64)
    mel_filters = mel_filters.to(device=device, dtype=torch.float64) if mel_filters is not None else None

    # Create output buffer
    spectrogram = torch.empty((B, num_frames, num_bins), dtype=torch.complex128, device=device)
    buffer = torch.zeros((B, fft_length), dtype=torch.float64, device=device)

    for frame_idx in range(num_frames):
        t0 = frame_idx * hop_length
        buffer[:, :frame_length] = padded_waveforms[:, t0 : t0 + frame_length]

        # Dither
        if dither != 0.0:
            buffer[:, :frame_length] += dither * torch.randn_like(buffer[:, :frame_length])

        # DC offset removal
        if remove_dc_offset:
            buffer[:, :frame_length] -= buffer[:, :frame_length].mean(dim=1, keepdim=True)

        # Preemphasis
        if preemphasis is not None:
            buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
            buffer[:, 0] *= 1 - preemphasis

        # Apply window
        buffer[:, :frame_length] *= window

        # FFT
        spectrogram[:, frame_idx] = fft_func(buffer, n=fft_length)

    # Magnitude / power
    if power is not None:
        spectrogram = torch.abs(spectrogram).pow(power)

    # Mel projection
    if mel_filters is not None:
        # spectrogram: [batch, num_frames, num_bins], mel_filters: [num_bins, num_mels]
        if mel_filters.shape[0] != spectrogram.shape[-1]:
            raise ValueError(
                f"Mel filter input bins ({mel_filters.shape[0]}) must match spectrogram frequency bins ({spectrogram.shape[-1]}). "
                f"Please check that mel_filters were designed for fft_length={spectrogram.shape[-1] * 2 - 2 if spectrogram.shape[-1] > 1 else spectrogram.shape[-1]}."
            )
        spectrogram = torch.matmul(spectrogram, mel_filters)
        spectrogram = torch.maximum(spectrogram, torch.tensor(mel_floor, device=device))

    # Log scaling
    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = torch.log(torch.clamp(spectrogram, min=min_value))
        elif log_mel == "log10":
            spectrogram = torch.log10(torch.clamp(spectrogram, min=min_value))
        elif log_mel == "dB":
            ref = torch.tensor(reference, device=device)
            spectrogram = 10.0 * torch.log10(torch.clamp(spectrogram / ref, min=min_value))
            if db_range is not None:
                max_val = spectrogram.amax(dim=-1, keepdim=True)
                spectrogram = torch.maximum(spectrogram, max_val - db_range)
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    # Return list of [num_bins, num_frames_i]
    true_frames = [1 + (w.shape[-1] - frame_length) // hop_length for w in waveform_list]
    spec_list = [spectrogram[i, :n, :].T.to(dtype) for i, n in enumerate(true_frames)]
    return spec_list
```
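For reference, a minimal sketch of how the new helper could be called. The import path and the STFT parameters below are assumptions for illustration; they are not prescribed by this diff.

```python
import torch
from transformers.audio_utils import spectrogram_torch  # assumed import path on this branch

# Two mono clips of different lengths, since the function takes a list of 1D tensors.
waveforms = [torch.randn(16000), torch.randn(12000)]
window = torch.hann_window(400, periodic=False)  # window length must equal frame_length

specs = spectrogram_torch(
    waveforms,
    window=window,
    frame_length=400,
    hop_length=160,
    fft_length=512,
    power=2.0,        # power spectrogram; pass mel_filters/log_mel for a log-mel variant
    device="cpu",
    dtype="float32",
)
print([s.shape for s in specs])  # one [num_bins, num_frames_i] tensor per input clip
```

On a CUDA machine, passing `device="cuda"` keeps the whole loop on the GPU, which is the point of the torch variant mentioned in the review comment above.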
```python
def power_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
```
@@ -117,8 +117,7 @@ def _extract_fbank_features(

```diff
     waveform: np.ndarray,
 ) -> np.ndarray:
     """
-    Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
-    and hence the waveform should not be normalized before feature extraction.
+    Get mel-filter bank features using Numpy method to mimic Kaldi.
     """
     # by default, it extracts the left channel if stereo
     if len(waveform.shape) == 2:
```

Contributor: Update docstring since it wasn't using TorchAudio!
@@ -558,7 +558,7 @@ def forward(

````diff
     >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

-    >>> outputs = model(**inputs)
+    >>> outputs = model(inputs["input_values"])
     >>> audio_codes = outputs.audio_codes
     >>> audio_values = outputs.audio_values
     ```
````

Contributor: @eustlb DAC, Xcodec, and Xcodec2 don't support `padding_mask`

Contributor: why is `padding_mask`

Contributor: For the codecs we currently have:
Is the model inherently incompatible with a padding mask approach or is it just not implemented in the original codebase?

Contributor: Padding has been added to input, see how it's used here
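To make the padding discussion concrete, here is a hedged sketch of how batched inputs would be prepared when the codec's forward only takes `input_values`. The checkpoint id is reused from the documentation example above; whether the feature extractor accepts a list of clips with these exact kwargs is an assumption.

```python
import torch
from transformers import AutoFeatureExtractor, Xcodec2Model

model_id = "hf-audio/xcodec2"  # checkpoint id taken from the doc example
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = Xcodec2Model.from_pretrained(model_id).eval()

# Clips of different lengths: the feature extractor pads them to a common length.
clips = [torch.randn(16000).numpy(), torch.randn(12000).numpy()]
inputs = feature_extractor(audio=clips, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")

with torch.no_grad():
    # No padding_mask is passed; the padding is already baked into input_values.
    outputs = model(inputs["input_values"])

audio_codes, audio_values = outputs.audio_codes, outputs.audio_values
```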
@@ -0,0 +1,28 @@

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_xcodec2 import *
    from .feature_extraction_xcodec2 import *
    from .modeling_xcodec2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
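For context, a short sketch of what this lazy `__init__` buys, assuming the branch is installed from source: the heavy submodules are only imported on first attribute access.

```python
# Importing the package runs only the lightweight lazy-module setup above.
from transformers.models import xcodec2

# First attribute access resolves the name through _LazyModule and imports the submodule.
config_cls = xcodec2.Xcodec2Config   # loads configuration_xcodec2
model_cls = xcodec2.Xcodec2Model     # loads modeling_xcodec2
print(config_cls.__module__, model_cls.__module__)
```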