From 2201673cbb680d9d9d381118111287d5535590be Mon Sep 17 00:00:00 2001 From: nithinraok Date: Fri, 31 Oct 2025 12:25:40 -0700 Subject: [PATCH 1/4] add support for saving encoder only so any decoder model can be loaded Signed-off-by: nithinraok --- .../modeling_fastspeech2_conformer.py | 6 +- .../models/parakeet/convert_nemo_to_hf.py | 112 +++++++++++++----- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 5a2dc39385b3..34b90ee6af28 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -490,12 +490,12 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None): kernel_size = module_config["kernel_size"] self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 - self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) self.depthwise_conv = nn.Conv1d( - channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True + channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias ) self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) def forward(self, hidden_states, attention_mask=None): """ diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index f1998fbd81b8..38b97ea0aaa8 100644 --- 
a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -24,8 +24,10 @@ from tokenizers import AddedToken from transformers import ( + ParakeetEncoderConfig, ParakeetCTCConfig, ParakeetFeatureExtractor, + ParakeetEncoder, ParakeetForCTC, ParakeetProcessor, ParakeetTokenizerFast, @@ -203,7 +205,8 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id= processor.push_to_hub(push_to_repo_id) -def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None): +def convert_encoder_config(nemo_config): + """Convert NeMo encoder config to HF encoder config.""" encoder_keys_to_ignore = [ "att_context_size", "causal_downsampling", @@ -220,8 +223,11 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i "stochastic_depth_mode", "conv_context_size", "dropout_pre_encoder", + "reduction", + "reduction_factor", + "reduction_position" ] - enocder_config_keys_mapping = { + encoder_config_keys_mapping = { "d_model": "hidden_size", "n_heads": "num_attention_heads", "n_layers": "num_hidden_layers", @@ -234,17 +240,23 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i "dropout_emb": "dropout_positions", "dropout_att": "attention_dropout", "xscaling": "scale_input", + 'use_bias': 'attention_bias', } converted_encoder_config = {} for key, value in nemo_config["encoder"].items(): if key in encoder_keys_to_ignore: continue - if key in enocder_config_keys_mapping: - converted_encoder_config[enocder_config_keys_mapping[key]] = value + if key in encoder_config_keys_mapping: + converted_encoder_config[encoder_config_keys_mapping[key]] = value else: - raise ValueError(f"Key {key} not found in enocder_config_keys_mapping") + raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") + + return ParakeetEncoderConfig(**converted_encoder_config) + +def load_and_convert_state_dict(model_files): + """Load NeMo 
state dict and convert keys to HF format.""" state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True) converted_state_dict = {} for key, value in state_dict.items(): @@ -255,35 +267,81 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i converted_key = convert_key(key, NEMO_TO_HF_WEIGHT_MAPPING) converted_state_dict[converted_key] = value - if model_type == "ctc": - model_config = ParakeetCTCConfig( - encoder_config=converted_encoder_config, - ) - print("Loading the checkpoint in a Parakeet CTC model.") - with torch.device("meta"): - model = ParakeetForCTC(model_config) - model.load_state_dict(converted_state_dict, strict=True, assign=True) - print("Checkpoint loaded successfully.") - del model.config._name_or_path + return converted_state_dict + + +def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None): + """Write CTC model using encoder config and converted state dict.""" + model_config = ParakeetCTCConfig.from_encoder_config(encoder_config) + + print("Loading the checkpoint in a Parakeet CTC model.") + with torch.device("meta"): + model = ParakeetForCTC(model_config) + model.load_state_dict(converted_state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del model.config._name_or_path - print("Saving the model.") - model.save_pretrained(output_dir) + print("Saving the model.") + model.save_pretrained(output_dir) - if push_to_repo_id: - model.push_to_hub(push_to_repo_id) + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) + + del model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + +def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None): + """Write encoder model using encoder 
config and converted state dict.""" + # Filter to only encoder weights (exclude CTC head if present) + encoder_state_dict = { + k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v + for k, v in converted_state_dict.items() + if k.startswith("encoder.") + } - del converted_state_dict, model + print(f"Loading the checkpoint in a Parakeet Encoder model (for TDT).") + with torch.device("meta"): + model = ParakeetEncoder(encoder_config) + + model.load_state_dict(encoder_state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del model.config._name_or_path - # Safety check: reload the converted model - gc.collect() - print("Reloading the model to check if it's saved correctly.") - ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") + print("Saving the model.") + model.save_pretrained(output_dir) + + if push_to_repo_id: + model.push_to_hub(push_to_repo_id) + del model + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + +def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None): + """Main model conversion function.""" + # Step 1: Convert encoder config (shared across all model types) + encoder_config = convert_encoder_config(nemo_config) + print(f"Converted encoder config: {encoder_config}") + + # Step 2: Load and convert state dict (shared across all model types) + converted_state_dict = load_and_convert_state_dict(model_files) + + # Step 3: Write model based on type + if model_type == "encoder": + write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) + elif model_type == "ctc": + write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id) else: raise ValueError(f"Model type 
{model_type} not supported.") - def main( hf_repo_id, output_dir, @@ -303,7 +361,7 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") - parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (`ctc`, `tdt`)") + parser.add_argument("--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)") parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") args = parser.parse_args() From 547f85611b05e8c1f73d6687e2afe1e9878dabbc Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Sun, 2 Nov 2025 18:42:15 +0100 Subject: [PATCH 2/4] use convolution_bias --- .../configuration_fastspeech2_conformer.py | 4 ++++ .../modeling_fastspeech2_conformer.py | 16 +++++++++++++--- .../models/parakeet/configuration_parakeet.py | 4 ++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py index 64c6a4eac8d7..aecfd9f18b2c 100644 --- a/src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py @@ -147,6 +147,8 @@ class FastSpeech2ConformerConfig(PreTrainedConfig): Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input. is_encoder_decoder (`bool`, *optional*, defaults to `True`): Specifies whether the model is an encoder-decoder. + convolution_bias (`bool`, *optional*, defaults to `True`): + Specifies whether to use bias in convolutions of the conformer's convolution module. 
Example: @@ -224,6 +226,7 @@ def __init__( num_languages=None, speaker_embed_dim=None, is_encoder_decoder=True, + convolution_bias=True, **kwargs, ): if positionwise_conv_kernel_size % 2 == 0: @@ -318,6 +321,7 @@ def __init__( self.speaker_embed_dim = speaker_embed_dim self.duration_predictor_dropout_rate = duration_predictor_dropout_rate self.is_encoder_decoder = is_encoder_decoder + self.convolution_bias = convolution_bias super().__init__( is_encoder_decoder=is_encoder_decoder, diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 34b90ee6af28..fa1544a0171c 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -490,12 +490,22 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None): kernel_size = module_config["kernel_size"] self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 - self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) + self.pointwise_conv1 = nn.Conv1d( + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + ) self.depthwise_conv = nn.Conv1d( - channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=config.attention_bias + channels, + channels, + kernel_size, + stride=1, + padding=self.padding, + groups=channels, + bias=config.convolution_bias, ) self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=config.attention_bias) + self.pointwise_conv2 = nn.Conv1d( + channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + ) def forward(self, hidden_states, attention_mask=None): """ diff --git 
a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py index 1e3d97b4182e..057259b04899 100644 --- a/src/transformers/models/parakeet/configuration_parakeet.py +++ b/src/transformers/models/parakeet/configuration_parakeet.py @@ -44,6 +44,8 @@ class ParakeetEncoderConfig(PreTrainedConfig): The non-linear activation function (function or string) in the encoder and pooler. attention_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention layers. + convolution_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in convolutions of the conformer's convolution module. conv_kernel_size (`int`, *optional*, defaults to 9): The kernel size of the convolution layers in the Conformer block. subsampling_factor (`int`, *optional*, defaults to 8): @@ -102,6 +104,7 @@ def __init__( intermediate_size=4096, hidden_act="silu", attention_bias=True, + convolution_bias=True, conv_kernel_size=9, subsampling_factor=8, subsampling_conv_channels=256, @@ -128,6 +131,7 @@ def __init__( self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.attention_bias = attention_bias + self.convolution_bias = convolution_bias if (conv_kernel_size - 1) % 2 != 0: raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}") From 5913d89a475c3dadb7fe5e7f453b93fe326f087b Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Sun, 2 Nov 2025 18:42:52 +0100 Subject: [PATCH 3/4] convert modular --- .../models/parakeet/modeling_parakeet.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 8ca7b7ff37d8..34697507ffc7 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -130,12 +130,22 @@ def __init__(self, config: ParakeetEncoderConfig, 
module_config=None): kernel_size = module_config["kernel_size"] self.activation = ACT2FN[module_config.get("activation", "silu")] self.padding = (kernel_size - 1) // 2 - self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv1 = nn.Conv1d( + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + ) self.depthwise_conv = nn.Conv1d( - channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True + channels, + channels, + kernel_size, + stride=1, + padding=self.padding, + groups=channels, + bias=config.convolution_bias, ) self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True) + self.pointwise_conv2 = nn.Conv1d( + channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias + ) def forward(self, hidden_states, attention_mask=None): """ From 692cc3b6d4303ec10d442ed3087667c2a5dd3ddc Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Sun, 2 Nov 2025 18:50:09 +0100 Subject: [PATCH 4/4] convolution_bias in conversion script --- .../models/parakeet/convert_nemo_to_hf.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/parakeet/convert_nemo_to_hf.py b/src/transformers/models/parakeet/convert_nemo_to_hf.py index 38b97ea0aaa8..e5cbe7f785db 100644 --- a/src/transformers/models/parakeet/convert_nemo_to_hf.py +++ b/src/transformers/models/parakeet/convert_nemo_to_hf.py @@ -24,10 +24,10 @@ from tokenizers import AddedToken from transformers import ( - ParakeetEncoderConfig, ParakeetCTCConfig, + ParakeetEncoder, + ParakeetEncoderConfig, ParakeetFeatureExtractor, - ParakeetEncoder, ParakeetForCTC, ParakeetProcessor, ParakeetTokenizerFast, @@ -225,7 +225,7 @@ def convert_encoder_config(nemo_config): "dropout_pre_encoder", "reduction", "reduction_factor", - "reduction_position" + 
"reduction_position", ] encoder_config_keys_mapping = { "d_model": "hidden_size", @@ -240,7 +240,7 @@ def convert_encoder_config(nemo_config): "dropout_emb": "dropout_positions", "dropout_att": "attention_dropout", "xscaling": "scale_input", - 'use_bias': 'attention_bias', + "use_bias": "attention_bias", } converted_encoder_config = {} @@ -249,6 +249,9 @@ def convert_encoder_config(nemo_config): continue if key in encoder_config_keys_mapping: converted_encoder_config[encoder_config_keys_mapping[key]] = value + # NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them + if key == "use_bias": + converted_encoder_config["convolution_bias"] = value else: raise ValueError(f"Key {key} not found in encoder_config_keys_mapping") @@ -295,6 +298,7 @@ def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_re ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") + def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None): """Write encoder model using encoder config and converted state dict.""" # Filter to only encoder weights (exclude CTC head if present) @@ -304,10 +308,10 @@ def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_t if k.startswith("encoder.") } - print(f"Loading the checkpoint in a Parakeet Encoder model (for TDT).") + print("Loading the checkpoint in a Parakeet Encoder model (for TDT).") with torch.device("meta"): model = ParakeetEncoder(encoder_config) - + model.load_state_dict(encoder_state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") del model.config._name_or_path @@ -325,6 +329,7 @@ def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_t ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") + def write_model(nemo_config, model_files, model_type, output_dir, 
push_to_repo_id=None): """Main model conversion function.""" # Step 1: Convert encoder config (shared across all model types) @@ -342,6 +347,7 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i else: raise ValueError(f"Model type {model_type} not supported.") + def main( hf_repo_id, output_dir, @@ -361,7 +367,9 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co") - parser.add_argument("--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)") + parser.add_argument( + "--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)" + ) parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model") parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub") args = parser.parse_args()