@@ -147,6 +147,8 @@ class FastSpeech2ConformerConfig(PreTrainedConfig):
Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Specifies whether the model is an encoder-decoder.
convolution_bias (`bool`, *optional*, defaults to `True`):
Specifies whether to use bias in convolutions of the conformer's convolution module.

Example:

@@ -224,6 +226,7 @@ def __init__(
num_languages=None,
speaker_embed_dim=None,
is_encoder_decoder=True,
convolution_bias=True,
**kwargs,
):
if positionwise_conv_kernel_size % 2 == 0:
@@ -318,6 +321,7 @@ def __init__(
self.speaker_embed_dim = speaker_embed_dim
self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
self.is_encoder_decoder = is_encoder_decoder
self.convolution_bias = convolution_bias

super().__init__(
is_encoder_decoder=is_encoder_decoder,
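The new convolution_bias argument is read like any other config field and defaults to True, so existing checkpoints keep their current behaviour. A minimal sketch, assuming a transformers build that includes this change:

from transformers import FastSpeech2ConformerConfig

# Hypothetical usage: drop the bias terms from the conformer convolution module.
config = FastSpeech2ConformerConfig(convolution_bias=False)
print(config.convolution_bias)  # False
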
@@ -490,12 +490,22 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
kernel_size = module_config["kernel_size"]
self.activation = ACT2FN[module_config.get("activation", "silu")]
self.padding = (kernel_size - 1) // 2
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
self.pointwise_conv1 = nn.Conv1d(
channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
)
self.depthwise_conv = nn.Conv1d(
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
channels,
channels,
kernel_size,
stride=1,
padding=self.padding,
groups=channels,
bias=config.convolution_bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
self.pointwise_conv2 = nn.Conv1d(
channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
)

def forward(self, hidden_states, attention_mask=None):
"""
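What bias=config.convolution_bias changes is easiest to see on a bare nn.Conv1d mirroring the depthwise convolution above; the channel and kernel values here are hypothetical, not the config defaults:

import torch.nn as nn

channels, kernel_size = 256, 31      # illustrative values only
convolution_bias = False             # stands in for config.convolution_bias

depthwise_conv = nn.Conv1d(
    channels, channels, kernel_size,
    stride=1, padding=(kernel_size - 1) // 2,
    groups=channels, bias=convolution_bias,
)
print(depthwise_conv.bias)           # None when convolution_bias is False
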
4 changes: 4 additions & 0 deletions src/transformers/models/parakeet/configuration_parakeet.py
@@ -44,6 +44,8 @@ class ParakeetEncoderConfig(PreTrainedConfig):
The non-linear activation function (function or string) in the encoder and pooler.
attention_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the attention layers.
convolution_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in convolutions of the conformer's convolution module.
conv_kernel_size (`int`, *optional*, defaults to 9):
The kernel size of the convolution layers in the Conformer block.
subsampling_factor (`int`, *optional*, defaults to 8):
@@ -102,6 +104,7 @@ def __init__(
intermediate_size=4096,
hidden_act="silu",
attention_bias=True,
convolution_bias=True,
conv_kernel_size=9,
subsampling_factor=8,
subsampling_conv_channels=256,
@@ -128,6 +131,7 @@ def __init__(
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.attention_bias = attention_bias
self.convolution_bias = convolution_bias

if (conv_kernel_size - 1) % 2 != 0:
raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
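As with FastSpeech2Conformer, convolution_bias is independent from attention_bias and defaults to True. A minimal sketch, assuming a transformers build that ships the Parakeet model:

from transformers import ParakeetEncoderConfig

config = ParakeetEncoderConfig(convolution_bias=False)
print(config.convolution_bias)  # False
print(config.attention_bias)    # True by default; the two flags are set independently

# conv_kernel_size must stay odd so (kernel_size - 1) // 2 gives symmetric padding:
# ParakeetEncoderConfig(conv_kernel_size=8) would raise ValueError
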
118 changes: 92 additions & 26 deletions src/transformers/models/parakeet/convert_nemo_to_hf.py
@@ -25,6 +25,8 @@

from transformers import (
ParakeetCTCConfig,
ParakeetEncoder,
ParakeetEncoderConfig,
ParakeetFeatureExtractor,
ParakeetForCTC,
ParakeetProcessor,
@@ -203,7 +205,8 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=
processor.push_to_hub(push_to_repo_id)


def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
def convert_encoder_config(nemo_config):
"""Convert NeMo encoder config to HF encoder config."""
encoder_keys_to_ignore = [
"att_context_size",
"causal_downsampling",
@@ -220,8 +223,11 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
"stochastic_depth_mode",
"conv_context_size",
"dropout_pre_encoder",
"reduction",
"reduction_factor",
"reduction_position",
]
enocder_config_keys_mapping = {
encoder_config_keys_mapping = {
"d_model": "hidden_size",
"n_heads": "num_attention_heads",
"n_layers": "num_hidden_layers",
@@ -234,17 +240,26 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
"dropout_emb": "dropout_positions",
"dropout_att": "attention_dropout",
"xscaling": "scale_input",
"use_bias": "attention_bias",
}
converted_encoder_config = {}

for key, value in nemo_config["encoder"].items():
if key in encoder_keys_to_ignore:
continue
if key in enocder_config_keys_mapping:
converted_encoder_config[enocder_config_keys_mapping[key]] = value
if key in encoder_config_keys_mapping:
converted_encoder_config[encoder_config_keys_mapping[key]] = value
# NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them
if key == "use_bias":
converted_encoder_config["convolution_bias"] = value
else:
raise ValueError(f"Key {key} not found in enocder_config_keys_mapping")
raise ValueError(f"Key {key} not found in encoder_config_keys_mapping")

return ParakeetEncoderConfig(**converted_encoder_config)
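
The use_bias special case is the one asymmetry in the mapping: NeMo carries a single flag, HF carries two. A self-contained sketch of that logic on a toy dict (key names are illustrative, not a full NeMo config), not the converter itself:

# Toy NeMo-style encoder section, for illustration only.
nemo_encoder = {"d_model": 1024, "n_heads": 8, "use_bias": False}

keys_mapping = {
    "d_model": "hidden_size",
    "n_heads": "num_attention_heads",
    "use_bias": "attention_bias",
}

converted = {}
for key, value in nemo_encoder.items():
    if key in keys_mapping:
        converted[keys_mapping[key]] = value
        # NeMo's single use_bias flag drives both HF flags.
        if key == "use_bias":
            converted["convolution_bias"] = value

print(converted)
# {'hidden_size': 1024, 'num_attention_heads': 8, 'attention_bias': False, 'convolution_bias': False}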


def load_and_convert_state_dict(model_files):
"""Load NeMo state dict and convert keys to HF format."""
state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
converted_state_dict = {}
for key, value in state_dict.items():
@@ -255,31 +270,80 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
converted_key = convert_key(key, NEMO_TO_HF_WEIGHT_MAPPING)
converted_state_dict[converted_key] = value

if model_type == "ctc":
model_config = ParakeetCTCConfig(
encoder_config=converted_encoder_config,
)
print("Loading the checkpoint in a Parakeet CTC model.")
with torch.device("meta"):
model = ParakeetForCTC(model_config)
model.load_state_dict(converted_state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
return converted_state_dict


def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
"""Write CTC model using encoder config and converted state dict."""
model_config = ParakeetCTCConfig.from_encoder_config(encoder_config)

print("Loading the checkpoint in a Parakeet CTC model.")
with torch.device("meta"):
model = ParakeetForCTC(model_config)
model.load_state_dict(converted_state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path

print("Saving the model.")
model.save_pretrained(output_dir)

if push_to_repo_id:
model.push_to_hub(push_to_repo_id)

print("Saving the model.")
model.save_pretrained(output_dir)
del model

if push_to_repo_id:
model.push_to_hub(push_to_repo_id)
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")

del converted_state_dict, model

# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
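
ParakeetCTCConfig.from_encoder_config builds the full CTC config around a converted encoder config. A hedged usage sketch; it assumes the encoder sub-config is exposed as the encoder_config attribute:

from transformers import ParakeetCTCConfig, ParakeetEncoderConfig

encoder_config = ParakeetEncoderConfig(convolution_bias=False)
ctc_config = ParakeetCTCConfig.from_encoder_config(encoder_config)
# Assumes the sub-config attribute is named encoder_config.
print(ctc_config.encoder_config.convolution_bias)  # False, carried over from the encoder config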
def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
"""Write encoder model using encoder config and converted state dict."""
# Filter to only encoder weights (exclude CTC head if present)
encoder_state_dict = {
k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v
for k, v in converted_state_dict.items()
if k.startswith("encoder.")
}

print("Loading the checkpoint in a Parakeet Encoder model (for TDT).")
with torch.device("meta"):
model = ParakeetEncoder(encoder_config)

model.load_state_dict(encoder_state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path

print("Saving the model.")
model.save_pretrained(output_dir)

if push_to_repo_id:
model.push_to_hub(push_to_repo_id)
del model

# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
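
The encoder-only export keeps just the encoder.* weights and strips that prefix before loading into ParakeetEncoder. The filtering can be checked on a toy state dict with made-up key names:

converted_state_dict = {
    "encoder.layers.0.attention.weight": 1,  # hypothetical key: kept, prefix stripped
    "ctc_head.weight": 2,                    # hypothetical key: dropped, not an encoder weight
}

encoder_state_dict = {
    k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v
    for k, v in converted_state_dict.items()
    if k.startswith("encoder.")
}
print(encoder_state_dict)  # {'layers.0.attention.weight': 1}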


def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
"""Main model conversion function."""
# Step 1: Convert encoder config (shared across all model types)
encoder_config = convert_encoder_config(nemo_config)
print(f"Converted encoder config: {encoder_config}")

# Step 2: Load and convert state dict (shared across all model types)
converted_state_dict = load_and_convert_state_dict(model_files)

# Step 3: Write model based on type
if model_type == "encoder":
write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
elif model_type == "ctc":
write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
else:
raise ValueError(f"Model type {model_type} not supported.")

@@ -303,7 +367,9 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (`ctc`, `tdt`)")
parser.add_argument(
"--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)"
)
parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model")
parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
args = parser.parse_args()
16 changes: 13 additions & 3 deletions src/transformers/models/parakeet/modeling_parakeet.py
@@ -130,12 +130,22 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None):
kernel_size = module_config["kernel_size"]
self.activation = ACT2FN[module_config.get("activation", "silu")]
self.padding = (kernel_size - 1) // 2
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
self.pointwise_conv1 = nn.Conv1d(
channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
)
self.depthwise_conv = nn.Conv1d(
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
channels,
channels,
kernel_size,
stride=1,
padding=self.padding,
groups=channels,
bias=config.convolution_bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
self.pointwise_conv2 = nn.Conv1d(
channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
)

def forward(self, hidden_states, attention_mask=None):
"""
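To see the flag take effect end to end, the encoder can be instantiated on the meta device (as the conversion script does) and its Conv1d modules inspected. A sketch assuming a transformers build that ships Parakeet:

import torch
from transformers import ParakeetEncoder, ParakeetEncoderConfig

config = ParakeetEncoderConfig(convolution_bias=False)
with torch.device("meta"):  # structural check only, no real weights allocated
    encoder = ParakeetEncoder(config)

# The pointwise/depthwise convolutions of each conformer block should now report no bias;
# convolutions outside the conformer convolution module are not governed by this flag.
for name, module in encoder.named_modules():
    if isinstance(module, torch.nn.Conv1d):
        print(name, "has bias:", module.bias is not None)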