36 changes: 36 additions & 0 deletions docs/source/en/model_doc/whisper.md
@@ -34,6 +34,42 @@ The original code can be found [here](https://github.com/openai/whisper).
- Inference is currently only implemented for short-form audio, i.e. audio that is pre-segmented into <=30s segments. Long-form transcription (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted IDs back into text.

This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
The original code can be found [here](https://github.com/openai/whisper).

## Inference

Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:

```python
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> waveform = audio_sample["array"]
>>> sampling_rate = audio_sample["sampling_rate"]

>>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
... ).input_features

>>> # Generate token ids
>>> predicted_ids = model.generate(input_features)

>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
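
A related pattern, shown here as a hedged sketch rather than something taken from the guide above: the multilingual `openai/whisper-tiny` checkpoint accepts a decoder prompt that pins the language and task. The snippet reuses the `waveform` and `sampling_rate` loaded earlier and is illustrative only.

```python
>>> # A minimal sketch (assumes the multilingual "openai/whisper-tiny" checkpoint):
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
>>> input_features = processor(waveform, sampling_rate=sampling_rate, return_tensors="pt").input_features
>>> # Pin the language and task through the decoder prompt ids:
>>> forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
>>> predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
>>> processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]  # doctest: +SKIP
```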

## WhisperConfig

[[autodoc]] WhisperConfig
20 changes: 10 additions & 10 deletions src/transformers/models/whisper/configuration_whisper.py
@@ -77,13 +77,13 @@ class WhisperConfig(PretrainedConfig):
num_mel_bins (`int`, *optional*, defaults to 80):
Number of mel features used per input features. Should correspond to the value used in the
`WhisperProcessor` class.
encoder_layers (`int`, *optional*, defaults to 6):
encoder_layers (`int`, *optional*, defaults to 4):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 6):
decoder_layers (`int`, *optional*, defaults to 4):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 4):
encoder_attention_heads (`int`, *optional*, defaults to 6):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 4):
decoder_attention_heads (`int`, *optional*, defaults to 6):
Number of attention heads for each attention layer in the Transformer decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 1536):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
@@ -106,7 +106,7 @@ class WhisperConfig(PretrainedConfig):
activation_function (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
d_model (`int`, *optional*, defaults to 256):
d_model (`int`, *optional*, defaults to 384):
Dimensionality of the layers.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
@@ -197,10 +197,10 @@ def __init__(
self,
vocab_size=51865,
num_mel_bins=80,
encoder_layers=6,
encoder_attention_heads=4,
decoder_layers=6,
decoder_attention_heads=4,
encoder_layers=4,
encoder_attention_heads=6,
decoder_layers=4,
decoder_attention_heads=6,
decoder_ffn_dim=1536,
encoder_ffn_dim=1536,
encoder_layerdrop=0.0,
@@ -209,7 +209,7 @@ def __init__(
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=256,
d_model=384,
dropout=0.0,
attention_dropout=0.0,
activation_dropout=0.0,
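
As a quick way to sanity-check the new defaults (a minimal sketch assuming a `transformers` release that includes this change; the asserted values simply restate the diff above):

```python
from transformers import WhisperConfig

# Defaults after this change correspond to the smallest ("tiny") Whisper architecture.
config = WhisperConfig()
assert config.encoder_layers == 4 and config.decoder_layers == 4
assert config.encoder_attention_heads == 6 and config.decoder_attention_heads == 6
assert config.d_model == 384

# The hidden size must split evenly across the attention heads: 384 / 6 = 64 per head.
assert config.d_model % config.encoder_attention_heads == 0
```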
14 changes: 9 additions & 5 deletions src/transformers/models/whisper/convert_openai_to_hf.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python
"""Converts a Whisper model in OpenAI format to Hugging Face format."""
# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +16,7 @@

import argparse
import hashlib
import io
import os
import urllib
import warnings
@@ -90,7 +93,7 @@ def make_linear_from_emb(emb):
return lin_layer


def _download(url: str, root: str) -> bytes:
def _download(url: str, root: str) -> io.BytesIO:
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)

@@ -103,7 +106,7 @@ def _download(url: str, root: str) -> bytes:
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes
return torch.load(io.BytesIO(model_bytes))
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

@@ -125,12 +128,13 @@ def _download(url: str, root: str) -> bytes:
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes
return torch.load(io.BytesIO(model_bytes))


def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
if ".pt" not in checkpoint_path:
original_checkpoint = _download(_MODELS[checkpoint_path])
root = os.path.dirname(pytorch_dump_folder_path) or "."
original_checkpoint = _download(_MODELS[checkpoint_path], root)
else:
original_checkpoint = torch.load(checkpoint_path, map_location="cpu")
dimensions = original_checkpoint["dims"]
@@ -151,7 +155,7 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
encoder_layers=dimensions["n_audio_layer"],
encoder_attention_heads=dimensions["n_audio_head"],
decoder_layers=dimensions["n_text_layer"],
decoder_attention_heads=dimensions["n_text_state"],
Contributor: Surprised we managed to convert the original checkpoints with this bug @ArthurZucker 🤔 The state dicts surely won't have matched? Maybe we hardcoded this before?

Collaborator: Yeah, but I think I hardcoded the values when converting and then later on made it automatic. I checked by actually re-running the script and seeing that this was a nice typo 🤣 but it's a good sign that no one else tried to convert the checkpoints!

Contributor Author (@zuazo, Oct 25, 2023):
Not sure about the history behind this. But looking at the blame, the script was correct once:

decoder_attention_heads=dimensions["n_text_head"],

Then it was deleted and recovered in: #20600

That's where the problem seems to come from. So the original script you used may have worked properly.

Collaborator: Nice digging! Yep, I think I uploaded an old version late by a few commits.

decoder_attention_heads=dimensions["n_text_head"],
max_source_positions=dimensions["n_audio_ctx"],
)

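
To make the issue discussed in the review thread concrete (a minimal sketch; the `dims` values below are the published sizes of the tiny checkpoint, assumed here for illustration and not read from this diff):

```python
# Assumed "dims" entries for the tiny checkpoint, for illustration only.
dims = {"n_text_state": 384, "n_text_head": 6}

buggy_heads = dims["n_text_state"]  # 384: this is the hidden size, not a head count
fixed_heads = dims["n_text_head"]   # 6: the actual number of decoder attention heads

# The attention projections are d_model x d_model either way, so the weights may still load,
# but splitting 384 dimensions into 384 one-dimensional heads instead of 6 heads of 64
# changes the attention computation, so a model converted with the old key would not
# reproduce the original checkpoint's outputs.
head_dim_buggy = dims["n_text_state"] // buggy_heads  # 1
head_dim_fixed = dims["n_text_state"] // fixed_heads  # 64
print(head_dim_buggy, head_dim_fixed)
```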