diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 2ab3e7938117..17f82828d968 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -344,6 +344,46 @@ def booleans_processing(config, **kwargs): return final_booleans +def unpack_inputs(func): + """ + Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables + downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input + (a common case in Keras). + + Args: + func (`callable`): + The callable function of the TensorFlow model. + + Returns: + A callable that wraps the original `func` with the behavior described above. + """ + + original_signature = inspect.signature(func) + + @functools.wraps(func) + def run_call_with_unpacked_inputs(self, *args, **kwargs): + # isolates the actual `**kwargs` for the decorated function + kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)} + fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call} + fn_args_and_kwargs.update({"kwargs_call": kwargs_call}) + + # move any positional arguments into keyword arguments, if any exist + fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) + + # process the inputs and call the wrapped function + main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1]) + main_input = fn_args_and_kwargs.pop(main_input_name) + unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs) + return func(self, **unpacked_inputs) + + # Keras requires the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This + # function does not follow wrapper chains (i.e. it ignores `functools.wraps()`), so without the line below Keras + # would check the first argument against the literal signature of the wrapper instead of the wrapped function. + run_call_with_unpacked_inputs.__signature__ = original_signature + + return run_call_with_unpacked_inputs + + def input_processing(func, config, input_ids, **kwargs): """ Process the input of each TensorFlow model including the booleans.
In case of a list of symbolic inputs, each input diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index f83dc186598b..5aaebaea5bc3 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -55,8 +55,8 @@ TFSequenceClassificationLoss, TFTokenClassificationLoss, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -720,6 +720,7 @@ class PreTrainedModel """ raise NotImplementedError + @unpack_inputs def call( self, input_ids: Optional[TFModelInputType] = None, @@ -738,59 +739,40 @@ def call( training: bool = False, **kwargs, ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) if not self.config.is_decoder: - inputs["use_cache"] = False + use_cache = False - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape - if inputs["past_key_values"] is None: + if past_key_values is None: past_key_values_length = 0 - inputs["past_key_values"] = [None] * len(self.encoder.layer) + past_key_values = [None] * len(self.encoder.layer) else: - past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + past_key_values_length = shape_list(past_key_values[0][0])[-2] - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) embedding_output = self.embeddings( - input_ids=inputs["input_ids"], - position_ids=inputs["position_ids"], - token_type_ids=inputs["token_type_ids"], - inputs_embeds=inputs["inputs_embeds"], + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, - training=inputs["training"], + training=training, ) # We create a 3D attention mask from a 2D tensor mask. 
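The `unpack_inputs` decorator added in the first hunk above is what makes the per-model cleanup below possible, so a minimal, self-contained sketch of the pattern may help: bind positional arguments to their parameter names, unpack a dictionary passed as the main input into keyword arguments, and restore the original signature so that Keras' argument check (which uses `inspect.getfullargspec()` and does not follow `functools.wraps`) still sees the real `call` signature. This is a simplified stand-in, not the library implementation (the real decorator delegates to `input_processing` and `main_input_name`), and `unpack_inputs_sketch`, `DummyLayer`, and the toy inputs are invented for illustration.

import functools
import inspect


def unpack_inputs_sketch(func):
    original_signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # bind positional arguments to their parameter names (skipping `self`)
        bound = dict(zip(list(original_signature.parameters)[1:], args))
        bound.update(kwargs)
        # if the main input arrived as a dict (the common Keras case), unpack it into keyword arguments
        main_input_name = list(original_signature.parameters)[1]
        if isinstance(bound.get(main_input_name), dict):
            bound.update(bound.pop(main_input_name))
        return func(self, **bound)

    # expose the wrapped signature: Keras checks the first argument via `inspect.getfullargspec()`,
    # which does not follow wrapper chains created by `functools.wraps()`
    wrapper.__signature__ = original_signature
    return wrapper


class DummyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        return {"input_ids": input_ids, "attention_mask": attention_mask, "training": training}


layer = DummyLayer()
# all three call styles reach `call` as named keyword arguments and print the same mapping
print(layer.call([1, 2, 3], attention_mask=[1, 1, 1]))
print(layer.call({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}))
print(layer.call(input_ids=[1, 2, 3], attention_mask=[1, 1, 1]))

Once the decorated `call` receives plain keyword arguments, the model bodies below can refer to `input_ids`, `attention_mask`, and friends directly, which is why every `inputs["..."]` lookup in the remaining hunks is replaced by the variable name itself.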
@@ -798,7 +780,7 @@ def call( # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) + attention_mask_shape = shape_list(attention_mask) mask_seq_length = seq_length + past_key_values_length # Copied from `modeling_tf_t5.py` @@ -811,18 +793,18 @@ def call( tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) - causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) - extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] attention_mask_shape = shape_list(extended_attention_mask) extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) - if inputs["past_key_values"][0] is not None: + if past_key_values[0] is not None: # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -836,18 +818,16 @@ def call( extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and inputs["encoder_attention_mask"] is not None: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast( - inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype - ) - num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 @@ -863,29 +843,29 @@ def call( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.config.num_hidden_layers + head_mask = [None] * self.config.num_hidden_layers encoder_outputs = self.encoder( hidden_states=embedding_output, attention_mask=extended_attention_mask, - head_mask=inputs["head_mask"], - encoder_hidden_states=inputs["encoder_hidden_states"], + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - if not inputs["return_dict"]: + if not return_dict: return ( sequence_output, pooled_output, @@ -1063,6 +1043,7 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): self.bert = TFBertMainLayer(config, name="bert") + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1108,9 +1089,7 @@ def call( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). Set to `False` during training, `True` during generation """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1125,25 +1104,7 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - return outputs def serving_output( @@ -1196,6 +1157,7 @@ def get_prefix_bias_name(self) -> str: warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1244,9 +1206,7 @@ def call( >>> outputs = model(input_ids) >>> prediction_scores, seq_relationship_scores = outputs[:2] ```""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1256,34 +1216,19 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, - next_sentence_label=next_sentence_label, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output, pooled_output = outputs[:2] - prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) seq_relationship_score = self.nsp(pooled_output=pooled_output) total_loss = None - if inputs["labels"] is not None and inputs["next_sentence_label"] is not None: - d_labels = {"labels": inputs["labels"]} - d_labels["next_sentence_label"] = inputs["next_sentence_label"] + if labels is not None and next_sentence_label is not None: + d_labels = {"labels": labels} + d_labels["next_sentence_label"] = next_sentence_label total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score)) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores, seq_relationship_score) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output @@ -1336,6 +1281,7 @@ def get_prefix_bias_name(self) -> str: warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1364,9 +1310,7 @@ def call( config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1376,31 +1320,13 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) - loss = ( - None - if inputs["labels"] is None - else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores) - ) + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1455,6 +1381,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + @unpack_inputs @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, @@ -1503,9 +1430,7 @@ def call( Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1519,37 +1444,19 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=inputs["training"]) + logits = self.mlm(sequence_output=sequence_output, training=training) loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1597,6 +1504,7 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): self.bert = TFBertMainLayer(config, name="bert") self.nsp = TFBertNSPHead(config, name="nsp___cls") + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1633,9 +1541,7 @@ def call( >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0] >>> assert logits[0][0] < logits[0][1] # the next sentence was random ```""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1645,31 +1551,17 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - next_sentence_label=next_sentence_label, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) pooled_output = outputs[1] seq_relationship_scores = self.nsp(pooled_output=pooled_output) next_sentence_loss = ( None - if inputs["next_sentence_label"] is None - else self.hf_compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores) + if next_sentence_label is None + else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) ) - if not inputs["return_dict"]: + if not return_dict: output = (seq_relationship_scores,) + outputs[2:] return 
((next_sentence_loss,) + output) if next_sentence_loss is not None else output @@ -1715,6 +1607,7 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): name="classifier", ) + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1743,9 +1636,7 @@ def call( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1755,28 +1646,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) + pooled_output = self.dropout(inputs=pooled_output, training=training) logits = self.classifier(inputs=pooled_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1825,6 +1702,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1852,51 +1730,26 @@ def call( Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - - if inputs["input_ids"] is not None: - num_choices = shape_list(inputs["input_ids"])[1] - seq_length = shape_list(inputs["input_ids"])[2] + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] else: - num_choices = shape_list(inputs["inputs_embeds"])[1] - seq_length = shape_list(inputs["inputs_embeds"])[2] + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] - flat_input_ids = ( - tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None - ) + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None flat_attention_mask = ( - tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length)) - if inputs["attention_mask"] is not None - else None + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None ) flat_token_type_ids = ( - tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length)) - if inputs["token_type_ids"] is not None - else None + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None ) flat_position_ids = ( - tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length)) - if inputs["position_ids"] is not None - else None + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None ) flat_inputs_embeds = ( - tf.reshape(tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3])) - if inputs["inputs_embeds"] is not None + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None else None ) outputs = self.bert( @@ -1904,22 +1757,20 @@ def call( attention_mask=flat_attention_mask, token_type_ids=flat_token_type_ids, position_ids=flat_position_ids, - head_mask=inputs["head_mask"], + head_mask=head_mask, inputs_embeds=flat_inputs_embeds, - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) + pooled_output = self.dropout(inputs=pooled_output, training=training) logits = self.classifier(inputs=pooled_output) reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = ( - None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits) - ) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - if not inputs["return_dict"]: + if not return_dict: output = (reshaped_logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1985,6 +1836,7 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): name="classifier", ) + @unpack_inputs 
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -2011,9 +1863,7 @@ def call( labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -2023,28 +1873,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"]) + sequence_output = self.dropout(inputs=sequence_output, training=training) logits = self.classifier(inputs=sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -2091,6 +1927,7 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): name="qa_outputs", ) + @unpack_inputs @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -2124,9 +1961,7 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -2136,22 +1971,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - start_positions=start_positions, - end_positions=end_positions, training=training, - kwargs_call=kwargs, - ) - outputs = self.bert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] logits = self.qa_outputs(inputs=sequence_output) @@ -2160,12 +1980,12 @@ def call( end_logits = tf.squeeze(input=end_logits, axis=-1) loss = None - if inputs["start_positions"] is not None and inputs["end_positions"] is not None: - labels = {"start_position": inputs["start_positions"]} - labels["end_position"] = inputs["end_positions"] + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - if not inputs["return_dict"]: + if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py index 201e904d952b..f9e330735635 100644 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ b/src/transformers/models/rembert/modeling_tf_rembert.py @@ -49,8 +49,8 @@ TFSequenceClassificationLoss, TFTokenClassificationLoss, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -642,6 +642,7 @@ class PreTrainedModel """ raise NotImplementedError + @unpack_inputs # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call def call( self, @@ -661,59 +662,40 @@ def call( training: bool = False, **kwargs, ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) if not self.config.is_decoder: - inputs["use_cache"] = False + use_cache = False - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = 
shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape - if inputs["past_key_values"] is None: + if past_key_values is None: past_key_values_length = 0 - inputs["past_key_values"] = [None] * len(self.encoder.layer) + past_key_values = [None] * len(self.encoder.layer) else: - past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + past_key_values_length = shape_list(past_key_values[0][0])[-2] - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) embedding_output = self.embeddings( - input_ids=inputs["input_ids"], - position_ids=inputs["position_ids"], - token_type_ids=inputs["token_type_ids"], - inputs_embeds=inputs["inputs_embeds"], + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, - training=inputs["training"], + training=training, ) # We create a 3D attention mask from a 2D tensor mask. @@ -721,7 +703,7 @@ def call( # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) + attention_mask_shape = shape_list(attention_mask) mask_seq_length = seq_length + past_key_values_length # Copied from `modeling_tf_t5.py` @@ -734,18 +716,18 @@ def call( tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) - causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) - extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] attention_mask_shape = shape_list(extended_attention_mask) extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) - if inputs["past_key_values"][0] is not None: + if past_key_values[0] is not None: # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -759,18 +741,16 @@ def call( extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and inputs["encoder_attention_mask"] is not None: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make 
broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast( - inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype - ) - num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 @@ -786,29 +766,29 @@ def call( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.config.num_hidden_layers + head_mask = [None] * self.config.num_hidden_layers encoder_outputs = self.encoder( hidden_states=embedding_output, attention_mask=extended_attention_mask, - head_mask=inputs["head_mask"], - encoder_hidden_states=inputs["encoder_hidden_states"], + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - if not inputs["return_dict"]: + if not return_dict: return ( sequence_output, pooled_output, @@ -955,6 +935,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs): self.rembert = TFRemBertMainLayer(config, name="rembert") + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1000,9 +981,7 @@ def call( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
Set to `False` during training, `True` during generation """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1017,23 +996,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -1077,6 +1039,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs): def get_lm_head(self) -> tf.keras.layers.Layer: return self.mlm.predictions + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1105,9 +1068,7 @@ def call( config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1117,31 +1078,13 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) - loss = ( - None - if inputs["labels"] is None - else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores) - ) + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1188,6 +1131,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + @unpack_inputs @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, checkpoint="rembert", @@ -1236,9 +1180,7 @@ def call( Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1252,37 +1194,19 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=inputs["training"]) + logits = self.mlm(sequence_output=sequence_output, training=training) loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1338,6 +1262,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs): name="classifier", ) + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1366,9 +1291,7 @@ def call( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1378,28 +1301,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) + pooled_output = self.dropout(inputs=pooled_output, training=training) logits = self.classifier(inputs=pooled_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1444,6 +1353,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1471,51 +1381,27 @@ def call( Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None: - num_choices = shape_list(inputs["input_ids"])[1] - seq_length = shape_list(inputs["input_ids"])[2] + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] else: - num_choices = shape_list(inputs["inputs_embeds"])[1] - seq_length = shape_list(inputs["inputs_embeds"])[2] + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] - flat_input_ids = ( - tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None - ) + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None flat_attention_mask = ( - tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length)) - if inputs["attention_mask"] is not None - else None + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None ) flat_token_type_ids = ( - tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length)) - if inputs["token_type_ids"] is not None - else None + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None ) flat_position_ids = ( - tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length)) - if inputs["position_ids"] is not None - else None + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None ) flat_inputs_embeds = ( - tf.reshape(tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3])) - if inputs["inputs_embeds"] is not None + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None else None ) outputs = self.rembert( @@ -1523,22 +1409,20 @@ def call( attention_mask=flat_attention_mask, token_type_ids=flat_token_type_ids, position_ids=flat_position_ids, - head_mask=inputs["head_mask"], + head_mask=head_mask, inputs_embeds=flat_inputs_embeds, - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"]) + pooled_output = self.dropout(inputs=pooled_output, training=training) logits = self.classifier(inputs=pooled_output) reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = ( - None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits) - ) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - if not inputs["return_dict"]: + if not return_dict: output = (reshaped_logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1589,6 +1473,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs): units=config.num_labels, 
kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1615,9 +1500,7 @@ def call( labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1627,28 +1510,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"]) + sequence_output = self.dropout(inputs=sequence_output, training=training) logits = self.classifier(inputs=sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1684,6 +1553,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs): units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) + @unpack_inputs @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1717,9 +1587,7 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.rembert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1729,22 +1597,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - start_positions=start_positions, - end_positions=end_positions, training=training, - kwargs_call=kwargs, - ) - outputs = self.rembert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] logits = self.qa_outputs(inputs=sequence_output) @@ -1753,12 +1606,12 @@ def call( end_logits = tf.squeeze(input=end_logits, axis=-1) loss = None - if inputs["start_positions"] is not None and inputs["end_positions"] is not None: - labels = {"start_position": inputs["start_positions"]} - labels["end_position"] = inputs["end_positions"] + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - if not inputs["return_dict"]: + if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index ee9e3d1457e8..98bde182de7f 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -50,8 +50,8 @@ TFSequenceClassificationLoss, TFTokenClassificationLoss, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -606,6 +606,7 @@ class PreTrainedModel """ raise NotImplementedError + @unpack_inputs # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call def call( self, @@ -625,59 +626,40 @@ def call( training: bool = False, **kwargs, ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) if not self.config.is_decoder: - inputs["use_cache"] = False + use_cache = False - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = 
shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape - if inputs["past_key_values"] is None: + if past_key_values is None: past_key_values_length = 0 - inputs["past_key_values"] = [None] * len(self.encoder.layer) + past_key_values = [None] * len(self.encoder.layer) else: - past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + past_key_values_length = shape_list(past_key_values[0][0])[-2] - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) embedding_output = self.embeddings( - input_ids=inputs["input_ids"], - position_ids=inputs["position_ids"], - token_type_ids=inputs["token_type_ids"], - inputs_embeds=inputs["inputs_embeds"], + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, - training=inputs["training"], + training=training, ) # We create a 3D attention mask from a 2D tensor mask. @@ -685,7 +667,7 @@ def call( # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) + attention_mask_shape = shape_list(attention_mask) mask_seq_length = seq_length + past_key_values_length # Copied from `modeling_tf_t5.py` @@ -698,18 +680,18 @@ def call( tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) - causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) - extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] attention_mask_shape = shape_list(extended_attention_mask) extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) - if inputs["past_key_values"][0] is not None: + if past_key_values[0] is not None: # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -723,18 +705,16 @@ def call( extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and inputs["encoder_attention_mask"] is not None: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make 
broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast( - inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype - ) - num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 @@ -750,29 +730,29 @@ def call( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.config.num_hidden_layers + head_mask = [None] * self.config.num_hidden_layers encoder_outputs = self.encoder( hidden_states=embedding_output, attention_mask=extended_attention_mask, - head_mask=inputs["head_mask"], - encoder_hidden_states=inputs["encoder_hidden_states"], + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - if not inputs["return_dict"]: + if not return_dict: return ( sequence_output, pooled_output, @@ -932,6 +912,7 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.roberta = TFRobertaMainLayer(config, name="roberta") + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -977,9 +958,7 @@ def call( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
Set to `False` during training, `True` during generation """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.roberta( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -994,23 +973,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -1107,6 +1069,7 @@ def get_prefix_bias_name(self): warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1135,10 +1098,8 @@ def call( config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.roberta( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1147,29 +1108,15 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] prediction_scores = self.lm_head(sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores) + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1221,6 +1168,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1270,9 +1218,7 @@ def call( Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. 
""" - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.roberta( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1286,38 +1232,20 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output, training=inputs["training"]) + logits = self.lm_head(hidden_states=sequence_output, training=training) loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1399,6 +1327,7 @@ def __init__(self, config, *inputs, **kwargs): self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") self.classifier = TFRobertaClassificationHead(config, name="classifier") + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1427,10 +1356,8 @@ def call( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.roberta( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1439,28 +1366,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=inputs["training"]) + logits = self.classifier(sequence_output, training=training) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) + loss = None if labels is None else self.hf_compute_loss(labels, logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1510,6 +1423,7 @@ def dummy_inputs(self): """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1537,60 +1451,38 @@ def call( Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None: - num_choices = shape_list(inputs["input_ids"])[1] - seq_length = shape_list(inputs["input_ids"])[2] + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] else: num_choices = shape_list(inputs_embeds)[1] seq_length = shape_list(inputs_embeds)[2] - flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None - flat_attention_mask = ( - tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None - ) - flat_token_type_ids = ( - tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None - ) - flat_position_ids = ( - tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None - ) + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None outputs = self.roberta( flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, - inputs["head_mask"], - inputs["inputs_embeds"], - inputs["output_attentions"], - inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, ) pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=inputs["training"]) + pooled_output = self.dropout(pooled_output, training=training) logits = self.classifier(pooled_output) reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits) + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - if not inputs["return_dict"]: + if not return_dict: output = (reshaped_logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1647,6 +1539,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1673,10 +1566,8 @@ def call( labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.roberta( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1685,30 +1576,16 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=inputs["training"]) + sequence_output = self.dropout(sequence_output, training=training) logits = self.classifier(sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) + loss = None if labels is None else self.hf_compute_loss(labels, logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1747,6 +1624,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) + @unpack_inputs @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1780,10 +1658,8 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.roberta( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1792,22 +1668,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - start_positions=start_positions, - end_positions=end_positions, training=training, - kwargs_call=kwargs, - ) - outputs = self.roberta( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] @@ -1817,12 +1678,12 @@ def call( end_logits = tf.squeeze(end_logits, axis=-1) loss = None - if inputs["start_positions"] is not None and inputs["end_positions"] is not None: - labels = {"start_position": inputs["start_positions"]} - labels["end_position"] = inputs["end_positions"] + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - if not inputs["return_dict"]: + if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 1e8e80f2622a..e2a4c4cccc04 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -37,8 +37,8 @@ TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -781,6 +781,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64) return attention_mask + @unpack_inputs def call( self, input_features=None, @@ -822,81 +823,64 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_features, - attention_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if "input_ids" in inputs: - inputs["input_features"] = inputs.pop("input_ids") - - if inputs["input_features"] is None: + if input_features is None: raise ValueError("You have to specify input_features") - inputs_embeds = self.conv(inputs["input_features"]) + inputs_embeds = self.conv(input_features) inputs_embeds = self.embed_scale * inputs_embeds # subsample attention mask if necessary - if inputs["attention_mask"] is not None: - inputs["attention_mask"] = self._get_feature_vector_attention_mask( - inputs_embeds.shape[1], inputs["attention_mask"] - ) - padding_mask = tf.cast(tf.math.not_equal(inputs["attention_mask"], 1), tf.int64) + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask) + padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64) else: padding_mask = tf.zeros(inputs_embeds.shape[:-1], dtype=tf.int64) embed_pos = self.embed_positions(padding_mask) hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # check attention mask and invert - if inputs["attention_mask"] is not None: + if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["attention_mask"] = _expand_mask(inputs["attention_mask"]) + attention_mask = _expand_mask(attention_mask) - encoder_states = () if inputs["output_hidden_states"] else None - all_attentions = () if inputs["output_attentions"] else None + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager. 
- if inputs["head_mask"] is not None and tf.executing_eagerly(): + if head_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], + shape_list(head_mask)[0], len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.", ) for idx, encoder_layer in enumerate(self.layers): - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer + if training and (dropout_probability < self.layerdrop): # skip the layer continue hidden_states, attn = encoder_layer( hidden_states, - inputs["attention_mask"], - inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + attention_mask, + head_mask[idx] if head_mask is not None else None, training=training, ) - if inputs["output_attentions"]: + if output_attentions: all_attentions += (attn,) hidden_states = self.layer_norm(hidden_states) - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return TFBaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions @@ -957,6 +941,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em return combined_attention_mask + @unpack_inputs def call( self, input_ids=None, @@ -1034,114 +1019,92 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length - past_key_values_length = ( - shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0 - ) + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - if inputs["inputs_embeds"] is None: - inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale else: - inputs_embeds = inputs["inputs_embeds"] + inputs_embeds = inputs_embeds attention_mask = self._prepare_decoder_attention_mask( - inputs["attention_mask"], input_shape, inputs_embeds, past_key_values_length + attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask - if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: + if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) # embed positions - positions = self.embed_positions(inputs["input_ids"], past_key_values_length=past_key_values_length) + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) hidden_states = inputs_embeds + positions - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # decoder layers - all_hidden_states = () if inputs["output_hidden_states"] else None - all_self_attns = () if inputs["output_attentions"] else None - all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None - next_decoder_cache = () if inputs["use_cache"] else None + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA 
then they have to be disabled in other modes than eager. - for attn_mask in ["head_mask", "cross_attn_head_mask"]: - if inputs[attn_mask] is not None and tf.executing_eagerly(): + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs[attn_mask])[0], + shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): + if training and (dropout_probability < self.layerdrop): continue - past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - cross_attn_layer_head_mask = ( - inputs["cross_attn_head_mask"][idx] if inputs["cross_attn_head_mask"] is not None else None - ) + past_key_value = past_key_values[idx] if past_key_values is not None else None + cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=attention_mask, - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, ) - if inputs["use_cache"]: + if use_cache: next_decoder_cache += (present_key_value,) - if inputs["output_attentions"]: + if output_attentions: all_self_attns += (layer_self_attn,) - if inputs["encoder_hidden_states"] is not None: + if encoder_hidden_states is not None: all_cross_attns += (layer_cross_attn,) hidden_states = self.layer_norm(hidden_states) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if not inputs["return_dict"]: + if not return_dict: return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns else: return TFBaseModelOutputWithPastAndCrossAttentions( @@ -1170,6 +1133,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.decoder.embed_tokens = new_embeddings + @unpack_inputs def call( self, input_features=None, @@ -1189,85 +1153,61 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_features, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, -
output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if "input_ids" in inputs: - inputs["input_features"] = inputs.pop("input_ids") - - # if the attribute is not set, fetch it from the config - for attr in ("output_attentions", "output_hidden_states", "use_cache"): - if inputs[attr] is None: - inputs[attr] = getattr(self.config, attr) - inputs["return_dict"] = ( - inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs["encoder_outputs"] is None: - inputs["encoder_outputs"] = self.encoder( - input_features=inputs["input_features"], - attention_mask=inputs["attention_mask"], - head_mask=inputs["head_mask"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput): - inputs["encoder_outputs"] = TFBaseModelOutput( - last_hidden_state=inputs["encoder_outputs"][0], - hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, - attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): - inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() # downsample encoder attention mask - if inputs["attention_mask"] is not None: - inputs["encoder_attention_mask"] = self.encoder._get_feature_vector_attention_mask( - inputs["encoder_outputs"][0].shape[1], inputs["attention_mask"] + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + encoder_outputs[0].shape[1], attention_mask ) else: - inputs["encoder_attention_mask"] = None + encoder_attention_mask = None # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( - input_ids=inputs["decoder_input_ids"], - attention_mask=inputs["decoder_attention_mask"], - encoder_hidden_states=inputs["encoder_outputs"][0], - encoder_attention_mask=inputs["encoder_attention_mask"], - 
head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) - if not inputs["return_dict"]: - return decoder_outputs + inputs["encoder_outputs"] + if not return_dict: + return decoder_outputs + encoder_outputs return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, @@ -1275,9 +1215,9 @@ def call( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, - encoder_hidden_states=inputs["encoder_outputs"].hidden_states, - encoder_attentions=inputs["encoder_outputs"].attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, ) @@ -1297,6 +1237,7 @@ def get_encoder(self): def get_decoder(self): return self.model.decoder + @unpack_inputs @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1323,10 +1264,8 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_features, + outputs = self.model( + input_features=input_features, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, @@ -1341,27 +1280,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - if "input_ids" in inputs: - inputs["input_features"] = inputs.pop("input_ids") - - outputs = self.model( - input_features=inputs["input_features"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - encoder_outputs=inputs["encoder_outputs"], - past_key_values=inputs["past_key_values"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -1412,6 +1330,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @unpack_inputs @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1473,61 +1392,35 @@ def call( >>> transcription = 
processor.batch_decode(generated_ids) ```""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_features, + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features=input_features, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - if "input_ids" in inputs: - inputs["input_features"] = inputs.pop("input_ids") - - inputs["return_dict"] = ( - inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict - ) - - if inputs["labels"] is not None: - if inputs["decoder_input_ids"] is None: - inputs["decoder_input_ids"] = shift_tokens_right( - inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_features=inputs["input_features"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - encoder_outputs=inputs["encoder_outputs"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) lm_logits = self.lm_head(outputs[0]) - masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - if not inputs["return_dict"]: + if not return_dict: output = (lm_logits,) + outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
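Note on the final hunk above (`TFSpeechToTextForConditionalGeneration.call`): when `labels` are provided and `decoder_input_ids` are not, the decoder inputs are derived from the labels via `shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)`, i.e. standard teacher forcing. The snippet below is a minimal, self-contained sketch of what such a helper does, not the library's exact implementation; the token ids in the usage example are arbitrary values chosen for illustration.

```python
import tensorflow as tf


def shift_tokens_right(labels: tf.Tensor, pad_token_id: int, decoder_start_token_id: int) -> tf.Tensor:
    """Build decoder inputs from labels for teacher forcing (minimal sketch).

    Prepends the decoder start token, drops the last label position, and replaces the
    -100 ignore index (used only for loss masking) with the pad token id so the shifted
    ids can be embedded safely.
    """
    batch_size = tf.shape(labels)[0]
    start_tokens = tf.fill([batch_size, 1], tf.cast(decoder_start_token_id, labels.dtype))
    shifted = tf.concat([start_tokens, labels[:, :-1]], axis=-1)
    # -100 marks positions the loss should ignore; it is not a real token id
    shifted = tf.where(shifted == -100, tf.cast(pad_token_id, labels.dtype), shifted)
    return shifted


# toy usage; token ids are made up for illustration
labels = tf.constant([[42, 43, 44, -100]], dtype=tf.int64)
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# decoder_input_ids -> [[2, 42, 43, 44]]
```

The shifted ids are what `self.model` consumes, while the loss (`hf_compute_loss`) is still computed against the original, unshifted `labels`.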