diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index fbf806cd41df..05b1947df3ac 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -99,6 +99,10 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] TFWhisperForConditionalGeneration - call +## TFWhisperForAudioClassification + +[[autodoc]] TFWhisperForAudioClassification + - call ## FlaxWhisperModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2253bda3908a..e9e444d82a5d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3705,6 +3705,7 @@ _import_structure["models.whisper"].extend( [ "TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFWhisperForAudioClassification", "TFWhisperForConditionalGeneration", "TFWhisperModel", "TFWhisperPreTrainedModel", @@ -7124,6 +7125,7 @@ ) from .models.whisper import ( TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFWhisperForAudioClassification, TFWhisperForConditionalGeneration, TFWhisperModel, TFWhisperPreTrainedModel, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index ecf9b06da5c6..72156dd6e468 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -357,7 +357,13 @@ ("xlnet", "TFXLNetForQuestionAnsweringSimple"), ] ) -TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")]) + +TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("wav2vec2", "TFWav2Vec2ForSequenceClassification"), + ("whisper", "TFWhisperForAudioClassification"), + ] +) TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ @@ -461,6 +467,10 @@ ) TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) +TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES +) + TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES) TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index cd962478e34d..e88bb4454315 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -63,6 +63,7 @@ "TFWhisperForConditionalGeneration", "TFWhisperModel", "TFWhisperPreTrainedModel", + "TFWhisperForAudioClassification", ] try: @@ -115,6 +116,7 @@ else: from .modeling_tf_whisper import ( TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFWhisperForAudioClassification, TFWhisperForConditionalGeneration, TFWhisperModel, TFWhisperPreTrainedModel, diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index 474c04499515..2f4e47d15502 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -32,6 +32,7 @@ TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, + TFSequenceClassifierOutput, ) from ...modeling_tf_utils import ( TFCausalLanguageModelingLoss, @@ -133,7 +134,7 @@ def build(self, input_shape): def call(self, input_ids, past_key_values_length=0): past_key_values_length = tf.cast(past_key_values_length, tf.int32) - gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length + gather_indices = tf.range(tf.shape(input_ids)[-1], delta=1) + past_key_values_length return tf.gather(self.weight, gather_indices) @@ -627,6 +628,12 @@ def __init__(self, config: WhisperConfig, **kwargs): self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout) + + def get_input_embeddings(self): + return self.conv1 + + def set_input_embeddings(self, value): + self.conv1 = value @unpack_inputs def call( @@ -1599,3 +1606,94 @@ def prepare_inputs_for_generation( "decoder_attention_mask": decoder_attention_mask, "decoder_position_ids": decoder_position_ids, } + + +class TFWhisperForAudioClassification(TFWhisperPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.encoder = TFWhisperEncoder(config) + num_layers = config.num_hidden_layers + 1 + if config.use_weighted_layer_sum: + self.layer_weights = tf.Variable(tf.ones(shape=(num_layers,)) / num_layers) + self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, input_shape=(config.hidden_size,)) + self.classifier = tf.keras.layers.Dense( + units=config.num_labels, input_shape=(config.classifier_proj_size,), activation=None + ) + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.encoder.set_input_embeddings(value) + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + return { + self.main_input_name: tf.random.uniform( + [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32 + ), + } + + @property + def input_signature(self): + return { + "input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"), + } + + @unpack_inputs + def call( + self, + input_features: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[tf.Tensor]]] = None, + labels: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if self.config.use_weighted_layer_sum: + hidden_states = tf.stack(encoder_outputs, axis=1) + norm_weights = tf.nn.softmax(self.layer_weights, axis=-1) + hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = tf.reduce_mean(hidden_states, axis=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels])) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 46cde8ffbef4..f5c294362ed4 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2699,6 +2699,13 @@ def __init__(self, *args, **kwargs): TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None +class TFWhisperForAudioClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFWhisperForConditionalGeneration(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 1bf5c2ccc230..97e1dcf37eec 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Testing suite for the TensorFlow Whisper model. """ - from __future__ import annotations +import copy import inspect import tempfile import traceback @@ -41,7 +41,12 @@ if is_tf_available(): import tensorflow as tf - from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed + from transformers import ( + TFWhisperForAudioClassification, + TFWhisperForConditionalGeneration, + TFWhisperModel, + set_seed, + ) from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder @@ -850,6 +855,190 @@ def _test_large_batched_generation(in_queue, out_queue, timeout): out_queue.join() +@require_tf +class TFWhisperEncoderModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=60, + is_training=True, + use_labels=True, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=30, + num_mel_bins=80, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + classifier_proj_size=4, + num_labels=2, + is_encoder_decoder=False, + is_decoder=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + self.classifier_proj_size = classifier_proj_size + self.num_labels = num_labels + self.is_encoder_decoder = is_encoder_decoder + self.is_decoder = is_decoder + + def get_config(self): + return WhisperConfig( + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + input_channels=self.input_channels, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + max_source_positions=self.max_source_positions, + decoder_ffn_dim=self.hidden_size, + encoder_ffn_dim=self.hidden_size, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + classifier_proj_size=self.classifier_proj_size, + num_labels=self.num_labels, + is_encoder_decoder=self.is_encoder_decoder, + is_decoder=self.is_decoder, + ) + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length]) + + config = self.get_config() + inputs_dict = prepare_whisper_encoder_inputs_dict( + config, + input_features=input_features, + ) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def get_subsampled_output_lengths(self, input_lengths): + """ + Computes the output length of the convolutional layers + """ + + for i in range(self.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + @property + def encoder_seq_length(self): + return self.get_subsampled_output_lengths(self.seq_length) + + def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False): + model = TFWhisperForAudioClassification(config=config) + + if freeze_encoder: + model.freeze_encoder() + + input_features = inputs_dict["input_features"] + + # first forward pass + last_hidden_state = model(input_features).logits + + self.parent.assertTrue(last_hidden_state.shape, (13, 2)) + + +def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): + if head_mask is None: + head_mask = tf.ones([config.encoder_layers, config.encoder_attention_heads]) + return {"input_features": input_features, "head_mask": head_mask} + + +@require_tf +class TFWhisperEncoderModelTest(TFModelTesterMixin, unittest.TestCase): + all_model_classes = (TFWhisperForAudioClassification,) if is_tf_available() else () + is_encoder_decoder = False + fx_compatible = False + test_pruning = False + test_missing_keys = False + test_onnx = False + + input_name = "input_features" + + def setUp(self): + self.model_tester = TFWhisperEncoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + self.maxDiff = 3000 + + def test_config(self): + self.config_tester.run_common_tests() + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_features", "head_mask", "encoder_outputs"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + # input embeds is meaningless for an encoder-only acoustic model + def test_inputs_embeds(self): + pass + + # the equivalent test is passing the encoder outputs directly to the model + def test_encoder_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + outputs = model(**inputs)[0] + + input_ids = inputs["input_features"] + + encoder = model.encoder + + inputs["encoder_outputs"] = encoder(input_ids) + outputs_embeds = model(**inputs)[0] + + self.assertTrue(tf.experimental.numpy.all(outputs_embeds == outputs)) + + # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented + def test_model_common_attributes(self): + pass + + # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass + + @require_tf @require_tokenizers class TFWhisperModelIntegrationTests(unittest.TestCase):