Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -43,6 +44,9 @@

logger = logging.get_logger(__name__)


_HIDDEN_STATES_START_POSITION = 2

_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_TOKENIZER_FOR_DOC = "Wav2Vec2Tokenizer"
Expand All @@ -58,6 +62,35 @@
LARGE_NEGATIVE = -1e8


@dataclass
class TFWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output type of the base Wav2Vec2 model, with potential hidden states and attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
            Sequence of extracted feature vectors of the last convolutional layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Defaults are None so the dataclass can be built incrementally; ModelOutput
    # drops None-valued fields when converting to a tuple/dict.
    last_hidden_state: tf.Tensor = None
    extract_features: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


def input_values_processing(func, config, input_values, **kwargs):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
Expand Down Expand Up @@ -707,10 +740,10 @@ def __init__(self, config: Wav2Vec2Config, **kwargs):
self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)

def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
return hidden_states, norm_hidden_states


# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
Expand Down Expand Up @@ -1222,19 +1255,20 @@ def call(
kwargs_call=kwargs,
)

hidden_states = self.feature_extractor(
extract_features = self.feature_extractor(
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
)
# extract_features = tf.transpose(extract_features, perm=(0, 2, 1))

if inputs["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))

attention_mask = tf.sequence_mask(
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
)

hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])

mask_time_indices = kwargs.get("mask_time_indices", None)
if inputs["training"]:
Expand All @@ -1251,10 +1285,11 @@ def call(
hidden_states = encoder_outputs[0]

if not inputs["return_dict"]:
return (hidden_states,) + encoder_outputs[1:]
return (hidden_states, extract_features) + encoder_outputs[1:]

return TFBaseModelOutput(
return TFWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
Expand Down Expand Up @@ -1635,7 +1670,7 @@ def call(
loss = None

if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output

return TFCausalLMOutput(
Expand Down