Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 29 additions & 195 deletions src/transformers/models/hubert/modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow Hubert model."""
import inspect
import warnings
from collections.abc import Mapping
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...modeling_tf_utils import (
TFPreTrainedModel,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
Expand All @@ -47,124 +49,6 @@
LARGE_NEGATIVE = -1e8


# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.input_values_processing
def input_values_processing(func, config, input_values, **kwargs):
    """
    Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
    has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
    dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
    training.

    Args:
        func (`callable`):
            The callable function of the TensorFlow model.
        config ([`PretrainedConfig`]):
            The config of the running model.
        input_values:
            The main input of the model. May be a single tensor (or `None`), a tuple/list of inputs matched to the
            parameters of `func` by tensor name or by position, or a mapping from parameter names to values. A
            mapping passed here is never modified.
        **kwargs:
            The inputs of the model.

    Returns:
        `dict`: The resolved call arguments, keyed by name. Every parameter of `func` (other than `self`, `args` and
        `kwargs`) is present, falling back to its signature default when it was not supplied. The entries
        `return_dict`, `output_attentions`, `output_hidden_states` and `use_cache` found in the result are replaced
        by whatever `booleans_processing` returns for them.

    Raises:
        ValueError: If a provided value is neither `None` nor an instance of one of the allowed types.
    """
    signature = dict(inspect.signature(func).parameters)
    signature.pop("kwargs", None)
    signature.pop("self", None)
    parameter_names = list(signature.keys())
    output = {}
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)

    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or v is None:
            output[k] = v
        else:
            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

    if isinstance(input_values, (tuple, list)):
        for i, value in enumerate(input_values):
            # EagerTensors don't allow to use the .name property so we check for a real Tensor.
            # The exact-type comparison is deliberate: `isinstance` would also accept subclasses.
            if type(value) == tf.Tensor:
                # Tensor names have always the pattern `name:id` then we check only the
                # `name` part
                tensor_name = value.name.split(":")[0]

                if tensor_name in parameter_names:
                    output[tensor_name] = value
                else:
                    output[parameter_names[i]] = value
            elif isinstance(value, allowed_types) or value is None:
                output[parameter_names[i]] = value
            else:
                raise ValueError(
                    f"Data of type {type(value)} is not allowed only {allowed_types} is accepted for"
                    f" {parameter_names[i]}."
                )
    elif isinstance(input_values, Mapping):
        # Work on a shallow copy: the deprecated keys are popped below, which must neither mutate the
        # caller's object nor fail on read-only mappings (`Mapping` itself does not define `pop`).
        input_values = dict(input_values)

        if "inputs" in input_values:
            warnings.warn(
                "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
                " instead.",
                FutureWarning,
            )

            output["input_values"] = input_values.pop("inputs")

        if "decoder_cached_states" in input_values:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
                " `past_key_values` instead.",
                FutureWarning,
            )
            output["past_key_values"] = input_values.pop("decoder_cached_states")

        for k, v in input_values.items():
            if isinstance(v, allowed_types) or v is None:
                output[k] = v
            elif k not in parameter_names and "args" not in parameter_names:
                logger.warning(
                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
                )
                continue
            else:
                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    else:
        if isinstance(input_values, tf.Tensor) or input_values is None:
            output[parameter_names[0]] = input_values
        else:
            raise ValueError(
                f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
                f" {parameter_names[0]}."
            )

    # Fill in every parameter that was not provided with the default declared in the signature of `func`.
    for name in parameter_names:
        if name not in output and name != "args":
            output[name] = kwargs.pop(name, signature[name].default)

    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
    # So to respect the proper output we have to add this exception
    if "args" in output:
        if output["args"] is not None and type(output["args"]) == tf.Tensor:
            tensor_name = output["args"].name.split(":")[0]
            output[tensor_name] = output["args"]
        else:
            # `args` in this case is always the first parameter, then `input_values`
            output["input_values"] = output["args"]

        del output["args"]

    if "kwargs" in output:
        del output["kwargs"]

    boolean_dict = {
        k: v
        for k, v in output.items()
        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
    }

    output.update(booleans_processing(config=config, **boolean_dict))

    return output


# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
def _sample_without_replacement(distribution, num_samples):
"""
Expand Down Expand Up @@ -1208,6 +1092,7 @@ def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: Optio

return hidden_states

@unpack_inputs
def call(
self,
input_values: tf.Tensor,
Expand All @@ -1222,51 +1107,33 @@ def call(
training: bool = False,
**kwargs: Any,
):
inputs = input_values_processing(
func=self.call,
config=self.config,
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)

hidden_states = self.feature_extractor(
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
)
hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)

if inputs["attention_mask"] is not None:
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))

attention_mask = tf.sequence_mask(
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
)

hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
hidden_states = self.feature_projection(hidden_states, training=training)

mask_time_indices = kwargs.get("mask_time_indices", None)
if inputs["training"]:
if training:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
hidden_states = encoder_outputs[0]

if not inputs["return_dict"]:
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]

return TFBaseModelOutput(
Expand Down Expand Up @@ -1428,6 +1295,7 @@ def __init__(self, config: HubertConfig, *inputs, **kwargs):

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
Expand Down Expand Up @@ -1469,9 +1337,11 @@ def call(
>>> hidden_states = model(input_values).last_hidden_state
```"""

inputs = input_values_processing(
func=self.call,
config=self.config,
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
output_attentions = output_attentions if output_attentions else self.config.output_attentions
return_dict = return_dict if return_dict else self.config.return_dict

outputs = self.hubert(
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
Expand All @@ -1484,27 +1354,6 @@ def call(
training=training,
)

inputs["output_hidden_states"] = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
)
inputs["output_attentions"] = (
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
)
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict

outputs = self.hubert(
input_values=inputs["input_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)

return outputs

def serving_output(self, output):
Expand Down Expand Up @@ -1548,6 +1397,7 @@ def freeze_feature_encoder(self):

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
Expand Down Expand Up @@ -1605,9 +1455,8 @@ def call(

>>> loss = model(input_values, labels=labels).loss
```"""
inputs = input_values_processing(
func=self.call,
config=self.config,

outputs = self.hubert(
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
Expand All @@ -1619,21 +1468,8 @@ def call(
return_dict=return_dict,
training=training,
)

outputs = self.hubert(
input_values=inputs["input_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)

logits = self.lm_head(hidden_states)

Expand All @@ -1642,9 +1478,7 @@ def call(
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

attention_mask = (
inputs["attention_mask"]
if inputs["attention_mask"] is not None
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
)
input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))

Expand All @@ -1671,7 +1505,7 @@ def call(
else:
loss = None

if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

Expand Down
Loading