Add xcodec2 model #37868
@@ -0,0 +1,84 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on 2025-02-06 and added to Hugging Face Transformers on 2025-04-29.*

# X-Codec2

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The X-Codec2 model was proposed in [Llasa: Scaling Train-Time and Inference-Time Compute for Llama-based Speech Synthesis](https://huggingface.co/papers/2502.04128).

X-Codec2 is a neural audio codec designed to improve speech synthesis and general audio generation for large language model (LLM) pipelines. It extends the original X-Codec by refining how semantic and acoustic information is integrated and tokenized, enabling efficient and high-fidelity audio representation.

Its architecture is based on [X-Codec](./xcodec) with several major differences:

- **Unified Semantic-Acoustic Tokenization**: X-Codec2 fuses outputs from a semantic encoder (e.g., Wav2Vec2-BERT) and an acoustic encoder into a single embedding, capturing both high-level meaning (e.g., text content, emotion) and low-level audio details (e.g., timbre).
- **Single-Stage Vector Quantization (VQ)**: Unlike the multi-layer residual VQ in most approaches (e.g., [X-Codec](./xcodec), [DAC](./dac), [EnCodec](./encodec)), X-Codec2 uses a single-layer Finite Scalar Quantization (FSQ) codebook for stability and compatibility with causal, autoregressive LLMs (see the sketch after this list).
- **Semantic Supervision During Training**: It adds a semantic reconstruction loss, ensuring that the discrete tokens preserve meaningful linguistic and emotional information — crucial for TTS tasks.
- **Transformer-Friendly Design**: The 1D token structure of X-Codec2 naturally aligns with the autoregressive modeling in LLMs like LLaMA, improving training efficiency and downstream compatibility.
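
To make the FSQ idea concrete, here is a minimal, illustrative sketch of finite scalar quantization. It is not the `Xcodec2Model` implementation; the tensor shapes and the tanh bounding follow the general FSQ formulation and are assumptions made for the example:

```python
import torch

levels = torch.tensor([4.0] * 8)  # per-dimension level counts, mirroring the default `vq_levels`

def fsq_quantize(z: torch.Tensor) -> torch.Tensor:
    """Bound each latent dimension so rounding yields exactly `levels[i]` values,
    then round with a straight-through estimator so gradients bypass the rounding."""
    half = (levels - 1) / 2
    offset = (levels % 2 == 0).float() * 0.5          # shift even grids off the origin
    shift = torch.atanh(offset / half)
    bounded = torch.tanh(z + shift) * half - offset   # 4 levels -> values near {-2, -1, 0, 1}
    rounded = torch.round(bounded)
    return bounded + (rounded - bounded).detach()     # straight-through estimator

z = torch.randn(1, 50, len(levels))                   # (batch, frames, quantizer dims), shapes assumed
codes = fsq_quantize(z)
# each frame indexes one of prod(levels) = 4**8 = 65536 implicit codebook entries,
# so a single FSQ stage replaces the multi-layer residual VQ used by other codecs
```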

## Usage example

Here is a quick example of how to encode and decode audio using this model:

```python
>>> import torch
>>> from datasets import Audio, load_dataset
>>> from transformers import AutoFeatureExtractor, Xcodec2Model

>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

>>> # load model and feature extractor
>>> model_id = "hf-audio/xcodec2"
>>> model = Xcodec2Model.from_pretrained(model_id).to(torch_device).eval()
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

>>> # load data
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
>>> audio = dataset[0]["audio"]["array"]

>>> # prepare data
>>> inputs = feature_extractor(raw_audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(torch_device)

>>> # encode and decode
>>> audio_codes = model.encode(inputs["input_values"]).audio_codes
>>> audio_values = model.decode(audio_codes).audio_values

>>> # or the equivalent with a forward pass
>>> model_output = model(inputs["input_values"])
>>> audio_codes = model_output.audio_codes
>>> audio_values = model_output.audio_values
```
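
The reconstructed waveform can then be written to disk for listening. A short, hedged follow-up, assuming `soundfile` is installed and that `audio_values` is a batched mono waveform (the exact output shape is not shown in this diff):

```python
>>> import soundfile as sf

>>> waveform = audio_values[0].squeeze().detach().cpu().numpy()
>>> sf.write("reconstruction.wav", waveform, feature_extractor.sampling_rate)
```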

This model was contributed by [Steven Zheng](https://huggingface.co/Steveeeeeeen) and [Eric Bezzam](https://huggingface.co/bezzam).
The original code can be found [here](https://github.com/zhenye234/X-Codec-2.0).

## Xcodec2Config

[[autodoc]] Xcodec2Config

## Xcodec2Model

[[autodoc]] Xcodec2Model
    - decode
    - encode
    - forward
@@ -539,7 +539,7 @@ def forward(

>>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

- >>> outputs = model(**inputs)
+ >>> outputs = model(inputs["input_values"])
Contributor
@eustlb DAC, Xcodec, and Xcodec2 don't support

Contributor
why is

Contributor
For the codecs we currently have:
Is the model inherently incompatible with a padding mask approach, or is it just not implemented in the original codebase?

Contributor
Padding has been added to the input, see how it's used here
>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values
```
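
As context for the padding discussion above, here is a hedged sketch of what a batched, padded call could look like, assuming an EncodecFeatureExtractor-style API where `padding=True` returns an `input_values`/`padding_mask` pair; whether Xcodec2's feature extractor exposes this is not confirmed by this diff:

```python
>>> # hypothetical batched call; argument names follow other codec feature extractors, not this PR
>>> batch = [dataset[0]["audio"]["array"], dataset[1]["audio"]["array"]]  # clips of different lengths
>>> inputs = feature_extractor(raw_audio=batch, sampling_rate=feature_extractor.sampling_rate, padding=True, return_tensors="pt")
>>> sorted(inputs.keys())  # typically includes "input_values" and, for codecs that support masking, "padding_mask"
```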
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_xcodec2 import *
    from .modeling_xcodec2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
@@ -0,0 +1,190 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Xcodec2 model configuration"""

import math

import numpy as np

from transformers import AutoConfig, Wav2Vec2BertConfig

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class Xcodec2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`Xcodec2Model`]. It is used to instantiate a
    Xcodec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [HKUSTAudio/xcodec2](https://huggingface.co/HKUSTAudio/xcodec2) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_hidden_size (`int`, *optional*, defaults to 1024):
            Hidden size for the audio encoder model.
        downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 2, 4, 4, 5]`):
            Ratios for downsampling in the encoder.
        decoder_hidden_size (`int`, *optional*, defaults to 1024):
            Hidden size for the audio decoder model.
        semantic_model_config (`Union[Dict, Wav2Vec2BertConfig]`, *optional*):
            An instance of the configuration object for the semantic (Wav2Vec2BertConfig) model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for the model.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key value heads for the model.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder.
        resnet_dropout (`float`, *optional*, defaults to 0.1):
            Dropout rate for the ResNet blocks in the decoder.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate for the attention layer.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the attention layer.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon for RMS normalization.
        head_dim (`int`, *optional*, defaults to 64):
            Head dimension for the model.
        vq_dim (`int`, *optional*, defaults to 2048):
            Dimension for the VQ codebook.
        vq_levels (`list[int]`, *optional*, defaults to `[4, 4, 4, 4, 4, 4, 4, 4]`):
            Levels for the VQ codebook.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the rotary position embeddings.
    """

    model_type = "xcodec2"

    sub_configs = {
        "semantic_model_config": Wav2Vec2BertConfig,
    }

    def __init__(
        self,
        encoder_hidden_size=1024,
        downsampling_ratios=[2, 2, 4, 4, 5],
        decoder_hidden_size=1024,
        semantic_model_config=None,
        initializer_range=0.02,
        sampling_rate=16000,
        num_attention_heads=16,
        num_key_value_heads=16,
        num_hidden_layers=12,
        resnet_dropout=0.1,
        attention_dropout=0.0,
        attention_bias=False,
        hidden_act="silu",
        rms_norm_eps=1e-6,
        head_dim=64,
        vq_dim=2048,
        vq_levels=[4, 4, 4, 4, 4, 4, 4, 4],
        max_position_embeddings=4096,
        rope_theta=10000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_hidden_size = encoder_hidden_size
        self.downsampling_ratios = downsampling_ratios

        self.semantic_model_id = "facebook/w2v-bert-2.0"  # needed for feature extractor
        if semantic_model_config is None:
            self.semantic_model_config = Wav2Vec2BertConfig()
        elif isinstance(semantic_model_config, dict):
            if "_name_or_path" in semantic_model_config:
                # If the config is a path, load it using AutoConfig
                self.semantic_model_config = AutoConfig.from_pretrained(semantic_model_config["_name_or_path"])
                self.semantic_model_id = semantic_model_config["_name_or_path"]
            else:
                # assume Wav2Vec2BertConfig, as the dict was probably created from scratch
                logger.warning(
                    "Could not determine semantic model type from config architecture. Defaulting to `Wav2Vec2BertConfig`."
                )
                self.semantic_model_config = Wav2Vec2BertConfig(**semantic_model_config)
        elif isinstance(semantic_model_config, Wav2Vec2BertConfig):
            self.semantic_model_config = semantic_model_config
        else:
            raise ValueError(
                f"semantic_model_config must be a dict or Wav2Vec2BertConfig instance, but got {type(semantic_model_config)}"
            )
        self.initializer_range = initializer_range
        self.sampling_rate = sampling_rate

        # decoder parameters (hybrid ResNet-Transformer architecture)
        self.decoder_hidden_size = decoder_hidden_size
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_hidden_layers = num_hidden_layers
        self.resnet_dropout = resnet_dropout
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        # single-codebook VQ is the main feature of Xcodec2
        self.num_quantizers = 1
        self.vq_dim = vq_dim
        self.vq_levels = vq_levels

    @property
    def frame_rate(self) -> int:
        return math.ceil(self.sampling_rate / self.hop_length)

    @property
    def semantic_hidden_size(self) -> int:
        return self.semantic_model_config.hidden_size

    @property
    def intermediate_size(self) -> int:
        # Semantic and acoustic features are combined for a "fused feature embedding"
        # See Encoder section on p. 3 of https://arxiv.org/pdf/2502.04128
        return self.encoder_hidden_size + self.semantic_hidden_size

    @property
    def hop_length(self) -> int:
        return int(np.prod(self.downsampling_ratios))

    @property
    def hidden_size(self) -> int:
        # For Transformer used in decoder
        # See Decoder > Acoustic Reconstruction on p. 3 of https://arxiv.org/pdf/2502.04128
        return self.decoder_hidden_size

    @property
    def codebook_size(self) -> int:
        return int(np.prod(self.vq_levels))

    @property
    def codebook_dim(self) -> int:
        return len(self.vq_levels)


__all__ = ["Xcodec2Config"]
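
As a quick sanity check of the derived properties above, here is a hedged usage sketch; it assumes this PR's branch is installed so that `Xcodec2Config` is importable, and the expected values are computed from the defaults shown in this diff:

```python
>>> from transformers import Xcodec2Config

>>> config = Xcodec2Config()
>>> config.hop_length      # prod([2, 2, 4, 4, 5]) = 320 samples per code frame
320
>>> config.frame_rate      # ceil(16000 / 320) = 50 code frames per second
50
>>> config.codebook_size   # prod([4] * 8) = 65536 entries in the single FSQ codebook
65536
>>> config.codebook_dim    # len(vq_levels): one quantized dimension per entry
8
```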