diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 353348730add..22b08e4e61bc 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -79,6 +79,11 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] WhisperForConditionalGeneration - forward +## WhisperForAudioClassification + +[[autodoc]] WhisperForAudioClassification + - forward + ## TFWhisperModel diff --git a/docs/source/en/tasks/audio_classification.mdx b/docs/source/en/tasks/audio_classification.mdx index 8fbb490c0367..d79bd9033eee 100644 --- a/docs/source/en/tasks/audio_classification.mdx +++ b/docs/source/en/tasks/audio_classification.mdx @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit -[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm) +[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm), [Whisper](../model_doc/whisper) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e3a79b7d8451..e721e01016b3 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2575,6 +2575,7 @@ _import_structure["models.whisper"].extend( [ "WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST", + "WhisperForAudioClassification", "WhisperForConditionalGeneration", "WhisperModel", "WhisperPreTrainedModel", @@ -5782,6 +5783,7 @@ ) from .models.whisper import ( WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, + WhisperForAudioClassification, WhisperForConditionalGeneration, WhisperModel, WhisperPreTrainedModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b69761483459..11228c618990 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -877,6 +877,7 @@ ("wav2vec2", "Wav2Vec2ForSequenceClassification"), ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"), ("wavlm", "WavLMForSequenceClassification"), + ("whisper", "WhisperForAudioClassification"), ] ) diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index 61c9c5fd5d5b..3b6015a56f6f 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -49,6 +49,7 @@ "WhisperForConditionalGeneration", "WhisperModel", "WhisperPreTrainedModel", + "WhisperForAudioClassification", ] try: @@ -99,6 +100,7 @@ else: from .modeling_whisper import ( WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, + WhisperForAudioClassification, WhisperForConditionalGeneration, WhisperModel, WhisperPreTrainedModel, diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py index 7aceffd5b436..3fe936fccf34 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -136,6 +136,12 @@ class WhisperConfig(PretrainedConfig): begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`): A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as the token for `" "` (`blank_token_id`) and the `eos_token_id` + use_weighted_layer_sum (`bool`, *optional*, defaults to `False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of [`WhisperForAudioClassification`]. + classifier_proj_size (`int`, *optional*, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. Only relevant when using an + instance of [`WhisperForAudioClassification`]. apply_spec_augment (`bool`, *optional*, defaults to `False`): Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech @@ -214,6 +220,8 @@ def __init__( eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], + use_weighted_layer_sum=False, + classifier_proj_size=256, apply_spec_augment=False, mask_time_prob=0.05, mask_time_length=10, @@ -244,6 +252,11 @@ def __init__( self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions + + # Audio Classification-specific parameters. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + self.use_weighted_layer_sum = use_weighted_layer_sum + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 self.apply_spec_augment = apply_spec_augment self.mask_time_prob = mask_time_prob diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 51fffe001dc2..2e4dfa67f9c7 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -32,6 +32,7 @@ BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -701,6 +702,33 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ +WHISPER_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + class WhisperEncoder(WhisperPreTrainedModel): """ @@ -1578,3 +1606,123 @@ def _reorder_cache(past_key_values, beam_idx): for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past + + +@add_start_docstrings( + """ + Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks + like SUPERB Keyword Spotting. + """, + WHISPER_ENCODER_INPUTS_DOCSTRING, +) +class WhisperForAudioClassification(WhisperPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.encoder = WhisperEncoder(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will + not be updated during training. Only the projection layers and classification head will be updated. + """ + self.encoder._freeze_parameters() + + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification + >>> from datasets import load_dataset + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + + >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) + >>> sample = next(iter(ds)) + + >>> inputs = feature_extractor( + ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt" + ... ) + >>> input_features = inputs.input_features + + >>> with torch.no_grad(): + ... logits = model(input_features).logits + + >>> predicted_class_ids = torch.argmax(logits).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + >>> predicted_label + 'af_za' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4328517226cd..d06bf1aea03a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6797,6 +6797,13 @@ def __init__(self, *args, **kwargs): WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None +class WhisperForAudioClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class WhisperForConditionalGeneration(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 77bd01a01013..6fa50a7f0819 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -43,6 +43,7 @@ from transformers import ( WhisperFeatureExtractor, + WhisperForAudioClassification, WhisperForConditionalGeneration, WhisperModel, WhisperProcessor, @@ -1372,3 +1373,191 @@ def test_tiny_specaugment_librispeech(self): ) # fmt: on self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + +def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) + return {"input_features": input_features, "head_mask": head_mask} + + +@require_torch +class WhisperEncoderModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=60, + is_training=True, + use_labels=True, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=30, + num_mel_bins=80, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + classifier_proj_size=4, + num_labels=2, + is_encoder_decoder=False, + is_decoder=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + self.classifier_proj_size = classifier_proj_size + self.num_labels = num_labels + self.is_encoder_decoder = is_encoder_decoder + self.is_decoder = is_decoder + + def get_config(self): + return WhisperConfig( + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + input_channels=self.input_channels, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + max_source_positions=self.max_source_positions, + decoder_ffn_dim=self.hidden_size, + encoder_ffn_dim=self.hidden_size, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + classifier_proj_size=self.classifier_proj_size, + num_labels=self.num_labels, + is_encoder_decoder=self.is_encoder_decoder, + is_decoder=self.is_decoder, + ) + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length]) + + config = self.get_config() + inputs_dict = prepare_whisper_encoder_inputs_dict( + config, + input_features=input_features, + ) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def get_subsampled_output_lengths(self, input_lengths): + """ + Computes the output length of the convolutional layers + """ + + for i in range(self.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + @property + def encoder_seq_length(self): + return self.get_subsampled_output_lengths(self.seq_length) + + def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False): + model = WhisperForAudioClassification(config=config).to(torch_device).eval() + + if freeze_encoder: + model.freeze_encoder() + + input_features = inputs_dict["input_features"] + + # first forward pass + last_hidden_state = model(input_features).logits + + self.parent.assertTrue(last_hidden_state.shape, (13, 2)) + + +@require_torch +class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (WhisperForAudioClassification,) if is_torch_available() else () + is_encoder_decoder = False + fx_compatible = False + test_pruning = False + test_missing_keys = False + + input_name = "input_features" + + def setUp(self): + self.model_tester = WhisperEncoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + self.maxDiff = 3000 + + def test_config(self): + self.config_tester.run_common_tests() + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_features", "head_mask", "encoder_outputs"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + # input embeds is meaningless for an encoder-only acoustic model + def test_inputs_embeds(self): + pass + + # the equivalent test is passing the encoder outputs directly to the model + def test_encoder_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + with torch.no_grad(): + outputs = model(**inputs)[0] + + input_ids = inputs["input_features"] + del inputs["input_features"] + + encoder = model.encoder + + with torch.no_grad(): + inputs["encoder_outputs"] = encoder(input_ids) + outputs_embeds = model(**inputs)[0] + + self.assertTrue((outputs_embeds == outputs).all()) + + # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented + def test_model_common_attributes(self): + pass + + # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass