diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
index d7e15fb4f41e..fb44424a1c74 100644
--- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -16,6 +16,7 @@
 
 import inspect
 import warnings
+from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -43,6 +44,9 @@
 
 logger = logging.get_logger(__name__)
 
+
+_HIDDEN_STATES_START_POSITION = 2
+
 _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
 _CONFIG_FOR_DOC = "Wav2Vec2Config"
 _TOKENIZER_FOR_DOC = "Wav2Vec2Tokenizer"
@@ -58,6 +62,35 @@
 LARGE_NEGATIVE = -1e8
 
 
+@dataclass
+class TFWav2Vec2BaseModelOutput(ModelOutput):
+    """
+    Output type of [`TFWav2Vec2Model`], with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+            Sequence of extracted feature vectors of the last convolutional layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: tf.Tensor = None
+    extract_features: tf.Tensor = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+
+
 def input_values_processing(func, config, input_values, **kwargs):
     """
     Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
@@ -707,10 +740,10 @@ def __init__(self, config: Wav2Vec2Config, **kwargs):
         self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
 
     def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = self.projection(hidden_states)
+        norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(norm_hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
-        return hidden_states
+        return hidden_states, norm_hidden_states
 
 
 # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
@@ -1222,19 +1255,20 @@ def call(
             kwargs_call=kwargs,
         )
 
-        hidden_states = self.feature_extractor(
+        extract_features = self.feature_extractor(
             tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
         )
+        # extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
 
         if inputs["attention_mask"] is not None:
             # compute real output lengths according to convolution formula
             output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
 
             attention_mask = tf.sequence_mask(
-                output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
+                output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
             )
 
-        hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
+        hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
 
         mask_time_indices = kwargs.get("mask_time_indices", None)
         if inputs["training"]:
@@ -1251,10 +1285,11 @@ def call(
         hidden_states = encoder_outputs[0]
 
         if not inputs["return_dict"]:
-            return (hidden_states,) + encoder_outputs[1:]
+            return (hidden_states, extract_features) + encoder_outputs[1:]
 
-        return TFBaseModelOutput(
+        return TFWav2Vec2BaseModelOutput(
             last_hidden_state=hidden_states,
+            extract_features=extract_features,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
@@ -1635,7 +1670,7 @@ def call(
         loss = None
 
         if not inputs["return_dict"]:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return ((loss,) + output) if loss is not None else output
 
         return TFCausalLMOutput(
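Reviewer note: a minimal usage sketch of the behavior this patch introduces, not part of the patch itself. It assumes the `facebook/wav2vec2-base-960h` checkpoint already referenced by `_CHECKPOINT_FOR_DOC`, and a dummy one-second input in place of real audio.

```python
from transformers import TFWav2Vec2Model, Wav2Vec2FeatureExtractor

# Checkpoint taken from _CHECKPOINT_FOR_DOC in this file.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# One second of silence at 16 kHz stands in for real audio.
inputs = feature_extractor([0.0] * 16000, sampling_rate=16000, return_tensors="tf")

# return_dict path: the new TFWav2Vec2BaseModelOutput exposes both tensors.
outputs = model(inputs.input_values)
print(outputs.last_hidden_state.shape)  # (batch_size, frames, hidden_size)
print(outputs.extract_features.shape)   # (batch_size, frames, conv_dim[-1])

# Tuple path: extract_features now sits at index 1, so hidden_states/attentions
# begin at index 2 — which is what _HIDDEN_STATES_START_POSITION encodes for
# downstream heads such as the CTC model's tuple slicing above.
last_hidden_state, extract_features = model(inputs.input_values, return_dict=False)[:2]
```

The design mirrors the PyTorch side: `Wav2Vec2FeatureProjection.forward` there also returns the layer-normed convolutional features alongside the projected hidden states, so the TF `call` returning `(hidden_states, norm_hidden_states)` keeps the two implementations in sync.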