Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/whisper.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] WhisperForConditionalGeneration
- forward

## WhisperForAudioClassification

[[autodoc]] WhisperForAudioClassification
- forward


## TFWhisperModel

Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/tasks/audio_classification.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm)
[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm), [Whisper](../model_doc/whisper)

<!--End of the generated tip-->

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,7 @@
_import_structure["models.whisper"].extend(
[
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
"WhisperForAudioClassification",
"WhisperForConditionalGeneration",
"WhisperModel",
"WhisperPreTrainedModel",
Expand Down Expand Up @@ -5782,6 +5783,7 @@
)
from .models.whisper import (
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
WhisperPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
("wavlm", "WavLMForSequenceClassification"),
("whisper", "WhisperForAudioClassification"),
]
)

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"WhisperForConditionalGeneration",
"WhisperModel",
"WhisperPreTrainedModel",
"WhisperForAudioClassification",
]

try:
Expand Down Expand Up @@ -99,6 +100,7 @@
else:
from .modeling_whisper import (
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
WhisperPreTrainedModel,
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/models/whisper/configuration_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class WhisperConfig(PretrainedConfig):
begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as
the token for `" "` (`blank_token_id`) and the `eos_token_id`
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of [`WhisperForAudioClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification. Only relevant when using an
instance of [`WhisperForAudioClassification`].
apply_spec_augment (`bool`, *optional*, defaults to `False`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Expand Down Expand Up @@ -214,6 +220,8 @@ def __init__(
eos_token_id=50256,
suppress_tokens=None,
begin_suppress_tokens=[220, 50256],
use_weighted_layer_sum=False,
classifier_proj_size=256,
apply_spec_augment=False,
mask_time_prob=0.05,
mask_time_length=10,
Expand Down Expand Up @@ -244,6 +252,11 @@ def __init__(
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions

# Audio Classification-specific parameters. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
self.use_weighted_layer_sum = use_weighted_layer_sum

# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
Expand Down
148 changes: 148 additions & 0 deletions src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
Expand Down Expand Up @@ -701,6 +702,33 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class WhisperEncoder(WhisperPreTrainedModel):
"""
Expand Down Expand Up @@ -1578,3 +1606,123 @@ def _reorder_cache(past_key_values, beam_idx):
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past


@add_start_docstrings(
"""
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
like SUPERB Keyword Spotting.
""",
WHISPER_ENCODER_INPUTS_DOCSTRING,
)
class WhisperForAudioClassification(WhisperPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.encoder = WhisperEncoder(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
Comment thread
sanchit-gandhi marked this conversation as resolved.
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training. Only the projection layers and classification head will be updated.
"""
self.encoder._freeze_parameters()

@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we could also use your pretrained model here if it is sota?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved in d6cba06

def forward(
self,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

Returns:

Example:

```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
>>> from datasets import load_dataset

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))

>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
... )
>>> input_features = inputs.input_features

>>> with torch.no_grad():
... logits = model(input_features).logits

>>> predicted_class_ids = torch.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'af_za'
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = encoder_outputs[0]

hidden_states = self.projector(hidden_states)
pooled_output = hidden_states.mean(dim=1)

logits = self.classifier(pooled_output)

loss = None

if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6797,6 +6797,13 @@ def __init__(self, *args, **kwargs):
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class WhisperForAudioClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class WhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading