Commits (69)
277a96f
Add xcodec model
Apr 29, 2025
349feae
code formatting
Apr 29, 2025
e5f1da8
typo xcodec2 name
Deep-unlearning May 20, 2025
fc0907c
add xcodec2 in init file
Deep-unlearning May 20, 2025
ea0acbf
fix import
Deep-unlearning May 20, 2025
8542db7
fix weight_norm init
Deep-unlearning May 21, 2025
e98d981
remove unused import
Deep-unlearning May 21, 2025
d1cd3ac
add convert file
Deep-unlearning May 26, 2025
74fa506
add ModelOutput class
Deep-unlearning May 26, 2025
02f5c94
nit
Deep-unlearning May 26, 2025
dd0a17c
fix device issue
May 27, 2025
3786203
fix forward
May 27, 2025
93dbfad
nit
May 27, 2025
d4d8c6a
doc draft
May 27, 2025
c40912e
draft test
May 27, 2025
17eb48c
match tensor with the original implementation
Jun 3, 2025
3760438
Add doc file for xcodec2
Jun 4, 2025
a2faa55
finish model doc for xcodec2
Jun 4, 2025
31319fb
update doc
Deep-unlearning Jun 5, 2025
8d9f8df
working xcodec2
Deep-unlearning Jun 5, 2025
e5a1838
add test file for xcodec2
Deep-unlearning Jun 5, 2025
473f95a
nit
Deep-unlearning Jun 24, 2025
dd8aace
xcodec2 use EncodecFeatureExtractor
Deep-unlearning Jul 7, 2025
f6cf875
Merge branch 'main' into add-xcodec2
ebezzam Aug 26, 2025
244bdb6
Standardize with Xcodec.
ebezzam Aug 29, 2025
a84a69f
Merge branch 'main' into add-xcodec2
ebezzam Aug 29, 2025
9d743e8
Merge branch 'main' into add-xcodec2
ebezzam Aug 29, 2025
2e23505
Address some PR comments and standardize.
ebezzam Sep 2, 2025
0316080
Remove Sequential.
ebezzam Sep 3, 2025
fcbeab7
Remove weight norm from model definition.
ebezzam Sep 3, 2025
3c50dd2
Remove padding.
ebezzam Sep 3, 2025
bfe535b
Better use of modular and better init.
ebezzam Sep 3, 2025
f287f6a
Style and format checks.
ebezzam Sep 3, 2025
88cc8a7
Address some modeling tests.
ebezzam Sep 4, 2025
8eddd59
Better use of modular for Attention and cleanup.
ebezzam Sep 5, 2025
dda588b
Remove asserts, expose params in config.
ebezzam Sep 5, 2025
1218679
Clean up internal and better docstrings.
ebezzam Sep 5, 2025
5378c81
Fix example.
ebezzam Sep 5, 2025
1005594
Make clear how self.causal is used
ebezzam Sep 6, 2025
ee09e64
Add padding consistent to original.
ebezzam Sep 6, 2025
857562f
Clean up and matching integration test.
ebezzam Sep 8, 2025
a92c67c
Update docs and clean up.
ebezzam Sep 8, 2025
de3a7f8
Correct doc location among audio models.
ebezzam Sep 8, 2025
767208b
Remove flash tests.
ebezzam Sep 8, 2025
48a69c2
Skip flash attention tests.
ebezzam Sep 8, 2025
2dfb96e
Switch to processor.
ebezzam Sep 26, 2025
1cb9e89
New feature extractor.
ebezzam Oct 2, 2025
892d8e8
Update modular.
ebezzam Oct 2, 2025
edf0738
Add feature extractor.
ebezzam Oct 2, 2025
185afea
Clean feature extractor and expose some parameters.
ebezzam Oct 2, 2025
755aaec
Clean up feature extractor and add torch support (wip).
ebezzam Oct 2, 2025
2a2d037
Clean up seamless feature extractor.
ebezzam Oct 2, 2025
af636f9
Modular cleanup.
ebezzam Oct 2, 2025
fcb0ee5
Remove processor.
ebezzam Oct 2, 2025
e10b04a
Update docs.
ebezzam Oct 2, 2025
bc50545
Fix copy paths.
ebezzam Oct 3, 2025
5493319
Fix modeling tests for audio_spectrogram input.
ebezzam Oct 3, 2025
2908abd
Fix torch support with new spectrogram torch utility.
ebezzam Oct 3, 2025
cf9fd28
Repo consistency.
ebezzam Oct 3, 2025
d97b48c
Feature extraction tests.
ebezzam Oct 3, 2025
3c0a5c7
Merge branch 'main' into add-xcodec2
ebezzam Oct 3, 2025
4d78a95
Make style happy
ebezzam Oct 3, 2025
bd9f37f
Remove unprotected import.
ebezzam Oct 3, 2025
3d365e5
Another unprotected import.
ebezzam Oct 3, 2025
6b812a0
Remove more unprotected torches.
ebezzam Oct 3, 2025
2bfc30a
zero_mean_unit_var_norm needed for a test
ebezzam Oct 3, 2025
1911438
Modify Vocos component to be able to use modular later.
ebezzam Oct 3, 2025
dd3f45f
Update modular
ebezzam Oct 3, 2025
8cfce62
Make style happy.
ebezzam Oct 3, 2025
2 changes: 1 addition & 1 deletion conftest.py
@@ -89,7 +89,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "torch_compile_test: mark test which tests torch compile functionality")
config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality")

-os.environ['DISABLE_SAFETENSORS_CONVERSION'] = 'true'
+os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"


def pytest_collection_modifyitems(items):
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -973,6 +973,8 @@
title: Whisper
- local: model_doc/xcodec
title: X-Codec
+- local: model_doc/xcodec2
+  title: X-Codec2
- local: model_doc/xls_r
title: XLS-R
- local: model_doc/xlsr_wav2vec2
Expand Down
89 changes: 89 additions & 0 deletions docs/source/en/model_doc/xcodec2.md
@@ -0,0 +1,89 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on 2025-02-06 and added to Hugging Face Transformers on 2025-04-29.*

# X-Codec2

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The X-Codec2 model was proposed in [Llasa: Scaling Train-Time and Inference-Time Compute for Llama-based Speech Synthesis](https://huggingface.co/papers/2502.04128).

X-Codec2 is a neural audio codec designed to improve speech synthesis and general audio generation for large language model (LLM) pipelines. It extends the original X-Codec by refining how semantic and acoustic information is integrated and tokenized, enabling efficient and high-fidelity audio representation.

Its architecture is based on [X-Codec](./xcodec) with several major differences:

- **Unified Semantic-Acoustic Tokenization**: X-Codec2 fuses outputs from a semantic encoder (e.g., Wav2Vec2-BERT) and an acoustic encoder into a single embedding, capturing both high-level meaning (e.g., text content, emotion) and low-level audio details (e.g., timbre); a schematic sketch follows this list.
- **Single-Stage Vector Quantization (VQ)**: Unlike the multi-layer residual VQ in most approaches (e.g., [X-Codec](./xcodec), [DAC](./dac), [EnCodec](./encodec)), X-Codec2 uses a single-layer Finite Scalar Quantization (FSQ) for stability and compatibility with causal, autoregressive LLMs.
- **Semantic Supervision During Training**: It adds a semantic reconstruction loss, ensuring that the discrete tokens preserve meaningful linguistic and emotional information — crucial for TTS tasks.
- **Transformer-Friendly Design**: The 1D token structure of X-Codec2 naturally aligns with the autoregressive modeling in LLMs like LLaMA, improving training efficiency and downstream compatibility.
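
The fusion-plus-single-quantizer idea above can be sketched as follows. This is illustrative only, with hypothetical module names and sizes, not the actual `Xcodec2Model` internals:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the real sub-modules (shapes chosen for illustration)
semantic_encoder = nn.Linear(320, 512)  # stands in for e.g. Wav2Vec2-BERT frame features
acoustic_encoder = nn.Linear(320, 512)  # stands in for the convolutional acoustic encoder
fuse_proj = nn.Linear(1024, 8)          # fuse both streams into one low-dim embedding

frames = torch.randn(1, 50, 320)        # [batch, frames, samples_per_frame]
fused = fuse_proj(torch.cat([semantic_encoder(frames), acoustic_encoder(frames)], dim=-1))

# Single-stage FSQ-style quantization: bound each dimension, round it to a few integer
# levels, and combine the per-dimension digits into one discrete token per frame.
levels = 4
bounded = torch.tanh(fused)                                 # squash each dim into (-1, 1)
digits = ((bounded + 1) / 2 * (levels - 1)).round().long()  # {0, ..., levels - 1}
place_values = levels ** torch.arange(digits.shape[-1])
codes = (digits * place_values).sum(dim=-1)                 # [batch, frames], a 1D token stream
print(codes.shape)  # torch.Size([1, 50])
```

With 8 dimensions and 4 levels each, this toy quantizer has an effective codebook of 4^8 = 65536 entries without ever materializing a codebook table, one reason single-stage FSQ tends to train more stably than residual VQ stacks.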

## Usage example

Here is a quick example of how to encode and decode an audio using this model:

```python
>>> import torch
>>> from datasets import Audio, load_dataset
>>> from transformers import AutoFeatureExtractor, Xcodec2Model

>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

>>> # load model and feature extractor
>>> model_id = "hf-audio/xcodec2"
>>> model = Xcodec2Model.from_pretrained(model_id).to(torch_device).eval()
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

>>> # load data
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
>>> audio = dataset[0]["audio"]["array"]

>>> # prepare data
>>> inputs = feature_extractor(audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(torch_device)

>>> # encode and decode
>>> audio_codes = model.encode(**inputs).audio_codes
>>> audio_values = model.decode(audio_codes).audio_values
>>> # or the equivalent with a forward pass
>>> model_output = model(**inputs)
>>> audio_codes = model_output.audio_codes
>>> audio_values = model_output.audio_values
```

This model was contributed by [Steven Zheng](https://huggingface.co/Steveeeeeeen) and [Eric Bezzam](https://huggingface.co/bezzam).
The original code can be found [here](https://github.com/zhenye234/X-Codec-2.0).


## Xcodec2Config

[[autodoc]] Xcodec2Config

## Xcodec2FeatureExtractor

[[autodoc]] Xcodec2FeatureExtractor
- __call__

## Xcodec2Model

[[autodoc]] Xcodec2Model
- decode
- encode
- forward
130 changes: 130 additions & 0 deletions src/transformers/audio_utils.py
@@ -33,12 +33,17 @@
is_librosa_available,
is_numpy_array,
is_soundfile_available,
is_torch_available,
is_torch_tensor,
is_torchcodec_available,
requires_backends,
)


if is_torch_available():
import torch
import torch.nn.functional as F

if TYPE_CHECKING:
import torch

@@ -1032,6 +1037,131 @@ def spectrogram_batch(
return spectrogram_list


> **ebezzam** (Contributor), Oct 3, 2025: Added a Torch equivalent to `spectrogram_batch()` (namely Mel feature extraction with Kaldi-style pre-processing, which I didn't see supported in other torch implementations), so that torch/GPU is supported by the feature extractor. Could also update SeamlessM4T?

def spectrogram_torch(
waveform_list: list["torch.Tensor"],
window: "torch.Tensor",
frame_length: int,
hop_length: int,
fft_length: Optional[int] = None,
power: float = 1.0,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional["torch.Tensor"] = None,
mel_floor: float = 1e-10,
log_mel: Optional[str] = None,
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = False,
device: str = "cpu",
dtype: str = "float32",
):
"""
PyTorch version of spectrogram_batch().

For spectrogram computation, tensors are promoted to `torch.float64` for better precision,
and returned according to `dtype`.
"""

window_length = len(window)
if fft_length is None:
fft_length = frame_length
if frame_length > fft_length:
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
if window_length != frame_length:
raise ValueError(f"window_length ({window_length}) must equal frame_length ({frame_length})")
if hop_length <= 0:
raise ValueError("hop_length must be greater than zero")
if dtype not in ["float16", "float32", "float64"]:
raise ValueError(f"dtype must be one of 'float16', 'float32', 'float64', got {dtype}")
dtype = getattr(torch, dtype)

# Convert list of waveforms → padded tensor [B, T]
max_len = max(w.shape[-1] for w in waveform_list)
padded_waveforms = torch.stack([F.pad(w, (0, max_len - w.shape[-1]), value=0.0) for w in waveform_list]).to(
device=device, dtype=torch.float64
)

# Optional centering (reflect pad)
if center:
pad_amt = frame_length // 2
padded_waveforms = F.pad(padded_waveforms, (pad_amt, pad_amt), mode=pad_mode)

B, T = padded_waveforms.shape
num_frames = 1 + (T - frame_length) // hop_length
fft_func = torch.fft.rfft if onesided else torch.fft.fft
num_bins = (fft_length // 2 + 1) if onesided else fft_length

# Promote to float64 for better precision
window = window.to(device=device, dtype=torch.float64)
mel_filters = mel_filters.to(device=device, dtype=torch.float64) if mel_filters is not None else None

# Create output buffer
spectrogram = torch.empty((B, num_frames, num_bins), dtype=torch.complex128, device=device)
buffer = torch.zeros((B, fft_length), dtype=torch.float64, device=device)

for frame_idx in range(num_frames):
t0 = frame_idx * hop_length
buffer[:, :frame_length] = padded_waveforms[:, t0 : t0 + frame_length]

# Dither
if dither != 0.0:
buffer[:, :frame_length] += dither * torch.randn_like(buffer[:, :frame_length])

# DC offset removal
if remove_dc_offset:
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(dim=1, keepdim=True)

# Preemphasis
if preemphasis is not None:
buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
buffer[:, 0] *= 1 - preemphasis

# Apply window
buffer[:, :frame_length] *= window

# FFT
spectrogram[:, frame_idx] = fft_func(buffer, n=fft_length)

# Magnitude / power
if power is not None:
spectrogram = torch.abs(spectrogram).pow(power)

# Mel projection
if mel_filters is not None:
# spectrogram: [batch, num_frames, num_bins], mel_filters: [num_bins, num_mels]
if mel_filters.shape[0] != spectrogram.shape[-1]:
raise ValueError(
f"Mel filter input bins ({mel_filters.shape[0]}) must match spectrogram frequency bins ({spectrogram.shape[-1]}). "
f"Please check that mel_filters were designed for fft_length={spectrogram.shape[-1] * 2 - 2 if spectrogram.shape[-1] > 1 else spectrogram.shape[-1]}."
)
spectrogram = torch.matmul(spectrogram, mel_filters)
spectrogram = torch.maximum(spectrogram, torch.tensor(mel_floor, device=device))

# Log scaling
if power is not None and log_mel is not None:
if log_mel == "log":
spectrogram = torch.log(torch.clamp(spectrogram, min=min_value))
elif log_mel == "log10":
spectrogram = torch.log10(torch.clamp(spectrogram, min=min_value))
elif log_mel == "dB":
ref = torch.tensor(reference, device=device)
spectrogram = 10.0 * torch.log10(torch.clamp(spectrogram / ref, min=min_value))
if db_range is not None:
max_val = spectrogram.amax(dim=-1, keepdim=True)
spectrogram = torch.maximum(spectrogram, max_val - db_range)
else:
raise ValueError(f"Unknown log_mel option: {log_mel}")

# Return list of [num_bins, num_frames_i]
true_frames = [1 + (w.shape[-1] - frame_length) // hop_length for w in waveform_list]
spec_list = [spectrogram[i, :n, :].T.to(dtype) for i, n in enumerate(true_frames)]
return spec_list


def power_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -377,6 +377,7 @@
from .whisper import *
from .x_clip import *
from .xcodec import *
+from .xcodec2 import *
from .xglm import *
from .xlm import *
from .xlm_roberta import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -444,6 +444,7 @@
("whisper", "WhisperConfig"),
("xclip", "XCLIPConfig"),
("xcodec", "XcodecConfig"),
("xcodec2", "Xcodec2Config"),
("xglm", "XGLMConfig"),
("xlm", "XLMConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
@@ -904,6 +905,7 @@
("whisper", "Whisper"),
("xclip", "X-CLIP"),
("xcodec", "X-CODEC"),
("xcodec2", "X-CODEC2"),
("xglm", "XGLM"),
("xlm", "XLM"),
("xlm-prophetnet", "XLM-ProphetNet"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -73,6 +73,7 @@
("wavlm", "Wav2Vec2FeatureExtractor"),
("whisper", "WhisperFeatureExtractor"),
("xcodec", "DacFeatureExtractor"),
("xcodec2", "Xcodec2FeatureExtractor"),
]
)

1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -423,6 +423,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("whisper", "WhisperModel"),
("xclip", "XCLIPModel"),
("xcodec", "XcodecModel"),
("xcodec2", "Xcodec2Model"),
("xglm", "XGLMModel"),
("xlm", "XLMModel"),
("xlm-prophetnet", "XLMProphetNetModel"),
@@ -117,8 +117,7 @@ def _extract_fbank_features(
waveform: np.ndarray,
) -> np.ndarray:
"""
-Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
-and hence the waveform should not be normalized before feature extraction.
+Get mel-filter bank features using a NumPy method to mimic Kaldi.
"""

> **ebezzam** (Contributor): Update docstring since it wasn't using TorchAudio!
# by default, it extracts the left channel if stereo
if len(waveform.shape) == 2:
2 changes: 1 addition & 1 deletion src/transformers/models/xcodec/modeling_xcodec.py
@@ -558,7 +558,7 @@ def forward(

>>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

->>> outputs = model(**inputs)
+>>> outputs = model(inputs["input_values"])

>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values
```

> **Contributor**: @eustlb DAC, Xcodec, and Xcodec2 don't support `model(**inputs)` as `padding_mask` is not an accepted input. Is that fine, or should `padding_mask` be added as an input even if it isn't used?

> **Contributor**: Why is `padding_mask` not an accepted input? Shouldn't it be accepted for batched inference?

> **Contributor**: For the codecs we currently have:
>
> - dac: not supported
> - encodec: does support padding mask
> - mimi: padding_mask supported, but not used
> - xcodec: not supported
>
> Is the model inherently incompatible with a padding mask approach, or is it just not implemented in the original codebase?

> **Contributor**: Padding has been added to input, see how it's used here.
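
For reference, a minimal sketch of the padding-mask style of batched inference discussed in this thread, using EnCodec's public API (assuming two raw 1D float waveforms `audio_1` and `audio_2` at 24 kHz with different lengths):

```python
>>> from transformers import AutoFeatureExtractor, EncodecModel

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")
>>> model = EncodecModel.from_pretrained("facebook/encodec_24khz")

>>> inputs = feature_extractor(raw_audio=[audio_1, audio_2], sampling_rate=24000, padding=True, return_tensors="pt")
>>> outputs = model(inputs["input_values"], inputs["padding_mask"])  # EnCodec accepts padding_mask
```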
28 changes: 28 additions & 0 deletions src/transformers/models/xcodec2/__init__.py
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_xcodec2 import *
from .feature_extraction_xcodec2 import *
from .modeling_xcodec2 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)