Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions src/transformers/models/whisper/convert_hf_to_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#!/usr/bin/env python
"""Converts a Whisper model in Hugging Face format to OpenAI format.

This script is based on the following script to do the opposite:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/convert_openai_to_hf.py

Requirements:

```bash
pip install -U openai-whisper
```

Example:

```bash
# Converts the model from Hugging Face to OpenAI format:
python convert_hf_to_openai.py \
--checkpoint openai/whisper-tiny \
--whisper_dump_path whisper-tiny-openai.pt
```

```python
>>> # Disabled doctest because it requries the openai-whisper package.
>> import whisper
>> from transformers.models.whisper.convert_hf_to_openai import convert_tfms_to_openai_whisper

>> # Converts the model from Hugging Face to OpenAI format:
>> convert_tfms_to_openai_whisper(
.. "openai/whisper-tiny", "whisper-tiny-openai.pt"
.. )
HF model path: openai/whisper-tiny
OpenAI model path: whisper-tiny-openai.pt
>> # Select an audio file:
>> audio_path = "https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav"

>> # Load the Whisper model in OpenAI format:
>> model = whisper.load_model("whisper-tiny-openai.pt")

>> # Transcribe the audio:
>> prediction = model.transcribe(audio_path)
>> prediction["text"][:70]
' chapter 16. I might have told you of the beginning of this liaison in'
```
"""
# Copyright 2023 Xabier de Zuazo and the Aholab team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from torch import nn

from transformers import WhisperConfig, WhisperForConditionalGeneration


# Create the reverse mapping adapting it from the original `WHISPER_MAPPING` in
# the `convert_openai_to_hf.py` script:
REVERSE_WHISPER_MAPPING = {
"layers": "blocks",
"fc1": "mlp.0",
"fc2": "mlp.2",
"final_layer_norm": "mlp_ln",
".self_attn.q_proj": ".attn.query",
".self_attn.k_proj": ".attn.key",
".self_attn.v_proj": ".attn.value",
".self_attn_layer_norm": ".attn_ln",
".self_attn.out_proj": ".attn.out",
".encoder_attn.q_proj": ".cross_attn.query",
".encoder_attn.k_proj": ".cross_attn.key",
".encoder_attn.v_proj": ".cross_attn.value",
".encoder_attn_layer_norm": ".cross_attn_ln",
".encoder_attn.out_proj": ".cross_attn.out",
"decoder.layer_norm.": "decoder.ln.",
"encoder.layer_norm.": "encoder.ln_post.",
"embed_tokens": "token_embedding",
"encoder.embed_positions.weight": "encoder.positional_embedding",
"decoder.embed_positions.weight": "decoder.positional_embedding",
}


def reverse_rename_keys(s_dict: dict) -> dict:
"""Renames the keys back from Hugging Face to OpenAI Whisper format.

By using this function on an HF model's state_dict, we should get the names in the format expected by Whisper.

Args:
s_dict (`dict`): A dictionary with keys in Hugging Face format.

Returns:
`dict`: The same dictionary but in OpenAI Whisper format.
"""
keys = list(s_dict.keys())
for orig_key in keys:
new_key = orig_key
for key_r, value_r in REVERSE_WHISPER_MAPPING.items():
if key_r in orig_key:
new_key = new_key.replace(key_r, value_r)

# print(f"{orig_key} -> {new_key}")

s_dict[new_key] = s_dict.pop(orig_key)
return s_dict


def make_emb_from_linear(linear: nn.Linear) -> nn.Embedding:
"""Converts a linear layer's weights into an embedding layer.

The linear layer's `in_features` dimension corresponds to the vocabulary size and its `out_features` dimension
corresponds to the embedding size.

Args:
linear (`nn.Linear`): The linear layer to be converted.

Returns:
`nn.Embedding`:
An embedding layer with weights set to those of the input linear layer.

"""
vocab_size, emb_size = linear.weight.data.shape
emb_layer = nn.Embedding(vocab_size, emb_size, _weight=linear.weight.data)
return emb_layer


def extract_dims_from_hf(config: WhisperConfig) -> dict:
"""Extracts necessary dimensions from Hugging Face's WhisperConfig.

Extracts necessary dimensions and related configuration data from the Hugging Face model and then restructure it
for the OpenAI Whisper format.

Args:
config (`WhisperConfig`): Configuration of the Hugging Face's model.

Returns:
`dict`: The `dims` of the OpenAI Whisper model.
"""
dims = {
"n_vocab": config.vocab_size,
"n_mels": config.num_mel_bins,
"n_audio_state": config.d_model,
"n_text_ctx": config.max_target_positions,
"n_audio_layer": config.encoder_layers,
"n_audio_head": config.encoder_attention_heads,
"n_text_layer": config.decoder_layers,
"n_text_head": config.decoder_attention_heads,
"n_text_state": config.d_model,
"n_audio_ctx": config.max_source_positions,
}
return dims


def convert_tfms_to_openai_whisper(hf_model_path: str, whisper_dump_path: str):
"""Converts a Whisper model from the Hugging Face to the OpenAI format.

Takes in the path to a Hugging Face Whisper model, extracts its state_dict, renames keys as needed, and then saves
the model OpenAI's format.

Args:
hf_model_path (`str`):
Path to the pretrained Whisper model in Hugging Face format.
whisper_dump_path (`str`):
Destination path where the converted model in Whisper/OpenAI format will be saved.

Returns:
`None`
"""
print("HF model path:", hf_model_path)
print("OpenAI model path:", whisper_dump_path)

# Load the HF model and its state_dict
model = WhisperForConditionalGeneration.from_pretrained(hf_model_path)
state_dict = model.state_dict()

# Use a reverse mapping to rename state_dict keys
state_dict = reverse_rename_keys(state_dict)

# Extract configurations and other necessary metadata
dims = extract_dims_from_hf(model.config)

# Remove the proj_out weights from state dictionary
del state_dict["proj_out.weight"]

# Construct the Whisper checkpoint structure
state_dict = {k.replace("model.", "", 1): v for k, v in state_dict.items()}
whisper_checkpoint = {"dims": dims, "model_state_dict": state_dict}

# Save in Whisper's format
torch.save(whisper_checkpoint, whisper_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint",
type=str,
help="Path of name of the Hugging Face checkpoint.", # noqa: E501
)
parser.add_argument(
"--whisper_dump_path",
type=str,
help="Path to the output Whisper model.", # noqa: E501
)
args = parser.parse_args()

convert_tfms_to_openai_whisper(args.checkpoint, args.whisper_dump_path)