diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 31f019e1b8e9..247467702e04 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -867,9 +867,8 @@ def _generate_beam_search( beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,)) - # cache compute states - past = encoder_outputs - # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None + # variable to cache compute states + past = None # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None @@ -886,6 +885,13 @@ def _generate_beam_search( if (return_dict_in_generate and kwargs["encoder_hidden_states"]) else None ) + # the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs` + # variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in + # `prepare_inputs_for_generation` + if encoder_hidden_states is not None: + encoder_outputs = (*encoder_outputs, encoder_hidden_states) + if encoder_attentions is not None: + encoder_outputs = (*encoder_outputs, encoder_attentions) # done sentences done = [False for _ in range(batch_size)] @@ -896,6 +902,7 @@ def _generate_beam_search( past=past, attention_mask=attention_mask, use_cache=use_cache, + encoder_outputs=encoder_outputs, **kwargs, ) outputs = self( @@ -1486,14 +1493,10 @@ def _generate( if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id) + # 4. Prepare model inputs which will be used for auto-regressive generation if self.config.is_encoder_decoder: - # if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - input_ids, return_dict_in_generate, model_kwargs - ) - - # 4. Prepare `input_ids` which will be used for auto-regressive generation - if self.config.is_encoder_decoder: + # if encoder-decoder, we create encoder_outputs and add to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) # if encoder-decoder then `input_ids` come from `decoder_start_token_id` input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, @@ -1531,10 +1534,6 @@ def _generate( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all - # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs. - model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None - # 8. run greedy search return self.greedy_search( input_ids, @@ -1559,10 +1558,6 @@ def _generate( **model_kwargs, ) - # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all - # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs. - model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None - # 10. 
run sample return self.sample( input_ids, @@ -1589,12 +1584,7 @@ def _prepare_attention_mask_for_generation( else: return tf.ones(input_ids.shape[:2], dtype=tf.int32) - def _prepare_encoder_decoder_kwargs_for_generation( - self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs - ) -> Dict[str, Any]: - # TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs` - # is cleaned - + def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]: # get encoder and store encoder outputs encoder = self.get_encoder() @@ -1612,17 +1602,8 @@ def _prepare_encoder_decoder_kwargs_for_generation( encoder_kwargs.pop("attention_mask") encoder_outputs = encoder(input_ids, **encoder_kwargs) - model_kwargs["encoder_outputs"] = encoder_outputs - # TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and - # `encoder_hidden_states` have to be seperated from encoder_outputs and passed - # under other names because of `encoder_outputs`, `past` hack. Need to clean-up - # all encoder-decoder prepare_inputs_for_generation method to clean this - if return_dict_in_generate: - model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None) - model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None) - return model_kwargs def _prepare_decoder_input_ids_for_generation( @@ -1712,27 +1693,17 @@ def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id return inputs + @staticmethod def _update_model_kwargs_for_generation( - self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False ) -> Dict[str, Any]: # update past - if self._use_cache(outputs, model_kwargs["use_cache"]): - # TODO(Patrick): `past`/`encoder_outputs` hack. This should be - # removed when cleaning up the encoder-decoder models - # if model has past, then set the past variable to speed up decoding - # make this method static then as well - model_kwargs["past"] = outputs[1] - elif "past_key_values" in outputs: + if "past_key_values" in outputs: model_kwargs["past"] = outputs.past_key_values elif "mems" in outputs: model_kwargs["past"] = outputs.mems elif "past_buckets_states" in outputs: model_kwargs["past"] = outputs.past_buckets_states - elif "past" in model_kwargs: - # TODO(Patrick) `past`/`encoder_outputs` hack. - # removed when cleaning up the encoder-decoder models. - # The line should not be necessary. - pass else: model_kwargs["past"] = None @@ -1907,26 +1878,18 @@ def greedy_search( cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs` - # to be wrapped into `past` variable. Tis is a bad design and needs - # to be updated. 
- # Remove the following lines when updating all encoder-decoder models - encoder_outputs = model_kwargs.pop("encoder_outputs", None) - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None - encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) # keep track of which sequences are already finished unfinished_sequences = tf.ones_like(input_ids[:, 0]) cur_len = input_ids.shape[-1] while cur_len < max_length: - # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation` - # in all models - model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"] - # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2129,25 +2092,18 @@ def sample( cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs` - # to be wrapped into `past` variable. This is a bad design and needs to be updated. - # Remove the following lines when updating all encoder-decoder models - encoder_outputs = model_kwargs.pop("encoder_outputs", None) - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None - encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) # keep track of which sequences are already finished unfinished_sequences = tf.ones_like(input_ids[:, 0]) cur_len = input_ids.shape[-1] while cur_len < max_length: - # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation` - # in all models - model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"] - # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index a618b8a48586..2b1df1a73586 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -16,7 +16,7 @@ import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorflow as tf @@ -1012,9 +1012,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1449,43 +1446,23 @@ def serving_output(self, output): def 
prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1499,15 +1476,10 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): @staticmethod def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 18dd4e754e5b..f83dc186598b 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -1443,17 +1443,17 @@ def get_prefix_bias_name(self) -> str: warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + # cut decoder_input_ids if past is used - if past: - inputs = tf.expand_dims(inputs[:, -1], -1) + if past is not None: + input_ids = input_ids[:, -1:] - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "past_key_values": past, - "use_cache": model_kwargs["use_cache"], - } + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1575,6 +1575,13 @@ def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausa logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns ) + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),) + return reordered_past + @add_start_docstrings( """Bert Model with a `next sentence prediction (classification)` head on top.""", diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index c0cf39921263..66d9e5ffb19f 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -18,7 +18,7 @@ import os import random import warnings -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorflow as tf @@ -1011,9 +1011,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1461,43 +1458,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." 
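# Illustrative sketch (toy shapes, not part of the patch) of the `_reorder_cache` rule that BART,
# Blenderbot and the other seq2seq models in this diff now share: during beam search only the
# self-attention key/value states (the first two entries of each layer's past) depend on the chosen
# beam, so only those are gathered along the batch axis; cached cross-attention states pass through.
import tensorflow as tf

def reorder_cache_sketch(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2])
            + layer_past[2:],
        )
    return reordered_past

# toy cache: 1 layer of (self_key, self_value, cross_key, cross_value), batch_size * num_beams = 4
layer_past = tuple(tf.random.normal((4, 2, 5, 8)) for _ in range(4))
reordered = reorder_cache_sketch((layer_past,), beam_idx=tf.constant([0, 0, 2, 3]))
assert reordered[0][0].shape == (4, 2, 5, 8)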
- encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1509,15 +1486,10 @@ def prepare_inputs_for_generation( @staticmethod # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 3c9a4b40f9de..43e67f43f738 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -16,7 +16,7 @@ import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorflow as tf @@ -1010,9 +1010,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1434,43 +1431,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." 
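# Illustrative sketch (hypothetical names and shapes) of the decoder-only
# prepare_inputs_for_generation contract that the *LMHeadModel classes in this diff move to:
# build an all-ones attention mask when none is supplied, and once a cache exists feed only the
# newest token id back into the model.
import tensorflow as tf

def lm_prepare_inputs_sketch(input_ids, past=None, attention_mask=None):
    if attention_mask is None:
        attention_mask = tf.ones(input_ids.shape)
    if past is not None:
        input_ids = input_ids[:, -1:]  # the cache already covers the prefix
    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

ids = tf.constant([[101, 7592, 2088], [101, 2023, 2003]])
first_step = lm_prepare_inputs_sketch(ids)  # no cache yet: full ids are kept
next_step = lm_prepare_inputs_sketch(ids, past=((tf.zeros((2, 2, 3, 8)),) * 2,))
assert next_step["input_ids"].shape == (2, 1)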
- encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1482,15 +1459,10 @@ def prepare_inputs_for_generation( @staticmethod # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index c72448310a85..3287c442e1ef 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -16,6 +16,7 @@ """ TF 2.0 CTRL model.""" import warnings +from typing import Tuple import numpy as np import tensorflow as tf @@ -659,12 +660,12 @@ def get_prefix_bias_name(self): warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name - def prepare_inputs_for_generation(self, inputs, past, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past: - inputs = tf.expand_dims(inputs[:, -1], -1) + input_ids = tf.expand_dims(input_ids[:, -1], -1) - return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]} + return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache} @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -758,6 +759,12 @@ def serving_output(self, output): return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns) + @staticmethod + def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]: + return tuple( + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past + ) + @add_start_docstrings( """ diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index a2668b75b117..4458b9c532e6 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -692,52 +692,21 @@ def serving_output(self, output): ) def prepare_inputs_for_generation( - self, - decoder_input_ids, - past, - attention_mask, - use_cache=None, - **kwargs, + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): - if past is None or len(past) not in {1, 2}: - raise ValueError(f"past has to be an iterable of length 1,2 got {past}") - - if len(past) == 1: - if not isinstance(past[0], tf.Tensor): - raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}") - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - if len(past) != 2: - raise ValueError( - "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." - ) - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - if not isinstance(encoder_outputs[0], tf.Tensor): - raise ValueError( - f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - ) - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - if not past_key_values: - raise ValueError( - f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" - ) - decoder_input_ids = decoder_input_ids[:, -1:] - - if not isinstance(encoder_outputs, TFBaseModelOutput): - raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.") - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy "attention_mask": attention_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete + "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, } + return input_dict def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @@ -750,9 +719,4 @@ def resize_token_embeddings(self, *args, **kwargs): def _reorder_cache(self, past, beam_idx): # apply decoder cache reordering here - if len(past) == 1: - return past - - encoder_outputs, past_key_values = past - - return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx)) + return self.decoder._reorder_cache(past, beam_idx) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index d4939594d5ea..98f78e16da99 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -851,12 +851,15 @@ def get_output_embeddings(self): def set_output_embeddings(self, value): self.set_input_embeddings(value) - def prepare_inputs_for_generation(self, inputs, past, **kwargs): + def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs): + # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. 
NB -- some GPT2 + # tests will need to be fixed after the change + # only last token for inputs_ids if past is defined in kwargs if past: inputs = tf.expand_dims(inputs[:, -1], -1) - return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]} + return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache} @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index e282db0e811f..1e9a05bb6daf 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -17,7 +17,7 @@ import random from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import tensorflow as tf @@ -2097,7 +2097,7 @@ def call( all_self_attns = all_self_attns if inputs["output_attentions"] else None all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + present_key_values = present_key_values if inputs["use_cache"] else None if not inputs["return_dict"]: return tuple( @@ -2527,45 +2527,26 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, + decoder_head_mask=None, use_cache=None, + encoder_outputs=None, **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, - TFLEDEncoderBaseModelOutput, - ), f"encoder_outputs should be a TFLEDEncoderBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. 
input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } @@ -2574,18 +2555,13 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): @staticmethod def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past def hf_compute_loss(self, labels, logits): """CrossEntropyLoss that ignores pad tokens""" diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index d81c052b6df4..d6b0b123d690 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -16,7 +16,7 @@ import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -1050,9 +1050,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1477,43 +1474,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. 
got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1528,18 +1505,13 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): @staticmethod # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past def adjust_logits_during_generation( self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index bd11d1601044..a7c7b40e690b 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -16,7 +16,7 @@ import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorflow as tf @@ -1034,9 +1034,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1462,43 +1459,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." 
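# Minimal sketch (toy values; the real methods also forward head masks and other kwargs) of the
# simplified seq2seq prepare_inputs_for_generation logic shown above: `encoder_outputs` now arrives
# as a plain keyword argument from model_kwargs, and when a cache is present only the most recent
# decoder token is re-fed.
import tensorflow as tf

def seq2seq_prepare_inputs_sketch(decoder_input_ids, past=None, attention_mask=None,
                                  use_cache=None, encoder_outputs=None, **kwargs):
    if past is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]  # cut decoder_input_ids if past is used
    return {
        "input_ids": None,  # encoder_outputs is defined, so encoder input_ids are not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
    }

decoder_ids = tf.constant([[0, 42, 17]])
toy_cache = ((tf.zeros((1, 2, 2, 8)),) * 4,)  # 1 layer of (self_k, self_v, cross_k, cross_v)
step = seq2seq_prepare_inputs_sketch(decoder_ids, past=toy_cache, encoder_outputs="<enc outputs>")
assert step["decoder_input_ids"].shape == (1, 1)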
- encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1513,15 +1490,10 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): @staticmethod # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 9461fa871ac4..0e3917e9d632 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -16,7 +16,7 @@ import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -1058,9 +1058,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -1485,43 +1482,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, - **kwargs, - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." 
- encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1536,15 +1513,10 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): @staticmethod # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], + tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) - return (past[0], reordered_past) + return reordered_past diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index 7ea2d3521b61..53a21864254b 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -16,14 +16,13 @@ """TFRAG model implementation.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np import tensorflow as tf from ...configuration_utils import PretrainedConfig from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings -from ...modeling_tf_outputs import TFBaseModelOutput from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list from ...utils import logging from .configuration_rag import RagConfig @@ -788,42 +787,28 @@ def set_retriever(self, retriever: RagRetriever): # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py def prepare_inputs_for_generation( - self, decoder_input_ids, past, attention_mask, use_cache, doc_scores, n_docs=None, **kwargs - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - - if len(past) == 1: - assert isinstance(past[0], tf.Tensor) - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - decoder_cached_states = None - else: - assert len(past) == 2 - # Note: encoder_outputs is never changed by Bart as a generator - encoder_outputs, decoder_cached_states = past - - if 
isinstance(encoder_outputs, tuple): - assert isinstance(encoder_outputs[0], tf.Tensor) - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - - assert ( - decoder_cached_states - ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past" - # if past is defined cut decoder_input_ids to last token + self, + decoder_input_ids, + past=None, + attention_mask=None, + use_cache=None, + encoder_outputs=None, + doc_scores=None, + n_docs=None, + **kwargs + ): + if past is not None: + # if past is defined use only last decoder_input_ids decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed + "input_ids": None, "encoder_outputs": encoder_outputs, "doc_scores": doc_scores, "context_attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, - "past_key_values": decoder_cached_states, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + "past_key_values": past, + "use_cache": use_cache, "do_marginalize": True, "n_docs": n_docs, } @@ -844,46 +829,19 @@ def question_encoder(self): def _reorder_cache(past, beam_idx): """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs""" - def tf_index_select(input_, dim, indices): - """ - Input: - input_(tensor): input tensor dim(int): dimension indices(list): selected indices list - Output: - mimic of torch_tensor.index_select(dim, indices) - - credit: - https://stackoverflow.com/questions/58464790/is-there-an-equivalent-function-of-pytorch-named-index-select-in-tensorflow - """ - shape = shape_list(input_) - if dim == -1: - dim = len(shape) - 1 - shape[dim] = 1 - - tmp = [] - for idx in indices: - begin = [0] * len(shape) - begin[dim] = idx - tmp.append(tf.slice(input_, begin, shape)) - res = tf.concat(tmp, axis=dim) - - return res - - def _reorder_stacked(hidden_states, new_order=beam_idx): + def _reorder_stacked(hidden_states, new_order): n_docs = hidden_states.shape[0] // new_order.shape[0] hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:])) - hidden_states = tf_index_select(hidden_states, 0, new_order) - return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:])) - - if len(past) == 1: - return past - - past_key_values = past[1] + hidden_states = tf.gather(hidden_states, new_order, axis=0) + result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:])) + return result reordered_past = () - for layer_past in past_key_values: + for layer_past in past: + # get the correct batch idx from decoder layer's batch dim for cross and self-attn reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),) - return (past[0], reordered_past) + return reordered_past def marginalize(self, seq_logits, doc_scores, n_docs=None): n_docs = n_docs if n_docs is not None else self.config.n_docs @@ -1268,14 +1226,6 @@ def generate( return_dict=True, ) - if return_dict_in_generate: - # TODO(Patrick): `encoder_outputs`, `past` hack. 
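# Toy walk-through (assumed shapes) of the tf.gather-based _reorder_stacked used by TFRAG above:
# RAG's decoder cache has batch_size * n_docs rows, so the doc dimension is folded back out
# before the beam index is applied, then the tensor is flattened again.
import tensorflow as tf

def reorder_stacked_sketch(hidden_states, new_order):
    n_docs = hidden_states.shape[0] // new_order.shape[0]
    hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
    hidden_states = tf.gather(hidden_states, new_order, axis=0)
    return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))

# 2 beams * 3 docs = 6 rows of cached states with 4 features each
states = tf.reshape(tf.range(24, dtype=tf.float32), (6, 4))
reordered = reorder_stacked_sketch(states, tf.constant([1, 1]))  # both beams copy beam 1
assert reordered.shape == (6, 4)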
- # Remove after cleaning encoder-decoder outputs - if output_attentions: - model_kwargs["encoder_attentions"] = encoder_outputs.attentions - if output_hidden_states: - model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states - decoder_input_ids = tf.fill( (batch_size * num_beams, 1), tf.cast(decoder_start_token_id, tf.int32), @@ -1366,10 +1316,6 @@ def extend_enc_output(tensor, num_beams=None): model_kwargs.pop("output_attentions", None) model_kwargs.pop("output_scores", None) - # TODO(Patrick): `encoder_outputs`, `past` hack. - # Remove after cleaning encoder-decoder outputs - model_kwargs["past"] = encoder_outputs - return self.greedy_search( input_ids=decoder_input_ids, max_length=max_length, diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py index c7b65fe3a157..201e904d952b 100644 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ b/src/transformers/models/rembert/modeling_tf_rembert.py @@ -1176,17 +1176,17 @@ def get_lm_head(self) -> tf.keras.layers.Layer: return self.mlm.predictions # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + # cut decoder_input_ids if past is used - if past: - inputs = tf.expand_dims(inputs[:, -1], -1) + if past is not None: + input_ids = input_ids[:, -1:] - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "past_key_values": past, - "use_cache": model_kwargs["use_cache"], - } + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1309,6 +1309,14 @@ def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausa logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns ) + @staticmethod + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),) + return reordered_past + @add_start_docstrings( """ diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index dd45b6fd2d36..ee9e3d1457e8 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -1209,17 +1209,17 @@ def get_prefix_bias_name(self): return self.name + "/" + self.lm_head.name # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + # cut 
decoder_input_ids if past is used - if past: - inputs = tf.expand_dims(inputs[:, -1], -1) + if past is not None: + input_ids = input_ids[:, -1:] - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "past_key_values": past, - "use_cache": model_kwargs["use_cache"], - } + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1344,6 +1344,14 @@ def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausa logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns ) + @staticmethod + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),) + return reordered_past + class TFRobertaClassificationHead(tf.keras.layers.Layer): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 0eba94521d25..1e8e80f2622a 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -1139,7 +1139,7 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - next_cache = (inputs["encoder_hidden_states"], next_decoder_cache) if use_cache else None + next_cache = next_decoder_cache if use_cache else None if not inputs["return_dict"]: return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns @@ -1571,26 +1571,17 @@ def prepare_inputs_for_generation( decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, + encoder_outputs=None, **kwargs ): - if past is not None and len(past) <= 2: - if not isinstance(past[0], tf.Tensor): - raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}") - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - if len(past) == 1: - past_key_values = None - else: - past_key_values = past[1] - if not past_key_values: - raise ValueError(f"decoder cached states must be truthy, got {past_key_values}") - decoder_input_ids = decoder_input_ids[:, -1:] - else: - raise ValueError(f"`past` must be an iterable with length 1 or 2, got {past}") + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] return { "input_features": None, # needs to be passed to make Keras.layer.__call__ happy "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -1601,15 +1592,7 @@ def prepare_inputs_for_generation( @staticmethod def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: - reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) - + layer_past_key_values[2:], - ) - return (past[0], reordered_past) + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in 
layer_past),) + return reordered_past diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index ca307df70ebc..91d1c019b5fc 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1256,15 +1256,13 @@ def call( return_dict=inputs["return_dict"], training=inputs["training"], ) + past = decoder_outputs[1] if inputs["use_cache"] else None if not inputs["return_dict"]: - past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] return decoder_outputs + inputs["encoder_outputs"] - past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None - return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=past, @@ -1483,8 +1481,8 @@ def call( loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) + past = decoder_outputs[1] if inputs["use_cache"] else None if not inputs["return_dict"]: - past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"] @@ -1509,8 +1507,6 @@ def call( attentions=attentions, ) - past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None - return TFSeq2SeqLMOutput( loss=loss, logits=logits, @@ -1544,65 +1540,57 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, - inputs, - past, - attention_mask, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, use_cache=None, - **kwargs, + encoder_outputs=None, + **kwargs ): - assert past is not None, "past has to be defined for encoder_outputs" - - # first step - if len(past) < 2: - encoder_outputs, past_key_values = past, None - else: - encoder_outputs, past_key_values = past[0], past[1] - if "encoder_hidden_states" in kwargs: - encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"]) - if "encoder_attentions" in kwargs: - encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"]) # cut decoder_input_ids if past is used - if past_key_values is not None: - inputs = inputs[:, -1:] + if past is not None: + input_ids = input_ids[:, -1:] return { - "input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy - "decoder_input_ids": inputs, # inputs are the decoder_input_ids - "past_key_values": past_key_values, + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "decoder_input_ids": input_ids, + "past_key_values": past, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, "use_cache": use_cache, } def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past, beam_idx) -> Tuple: + def _reorder_cache(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - - if len(past) < 2: + if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past - decoder_past = past[1] - past = (past[0],) reordered_decoder_past = () - - for layer_past_states in 
decoder_past: + for layer_past_states in past: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),) + reordered_layer_past_states = reordered_layer_past_states + ( + tf.gather(layer_past_state, beam_idx, axis=0), + ) - assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0]) + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape assert len(reordered_layer_past_states) == len(layer_past_states) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return past + (reordered_decoder_past,) + return reordered_decoder_past @add_start_docstrings( diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py index 4534a4884aa7..b5e21efa7bd5 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py @@ -1058,15 +1058,22 @@ def serving_output(self, output): attentions=attns, ) - def prepare_inputs_for_generation(self, inputs, past, **model_kwargs): - inputs = {"input_ids": inputs} + def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs): + inputs = {} # if past is defined in model kwargs then use it for faster decoding if past: inputs["mems"] = past + inputs["input_ids"] = tf.expand_dims(input_ids[:, -1], axis=-1) + else: + inputs["input_ids"] = input_ids return inputs + @staticmethod + def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]: + return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems] + @add_start_docstrings( """ diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 244c836b8c3f..0f63e343165d 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -722,45 +722,22 @@ def serving_output(self, output): cross_attentions=cross_attns, ) - def prepare_inputs_for_generation(self, decoder_input_ids, past, use_cache=None, **kwargs): - if past is None or len(past) not in {1, 2}: - raise ValueError(f"past has to be an iterable of length 1,2 got {past}") - - if len(past) == 1: - if not isinstance(past[0], tf.Tensor): - raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}") - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - if len(past) != 2: - raise ValueError( - "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." 
- ) - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - if not isinstance(encoder_outputs[0], tf.Tensor): - raise ValueError( - f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - ) - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - if not past_key_values: - raise ValueError( - f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" - ) - decoder_input_ids = decoder_input_ids[:, -1:] - - if not isinstance(encoder_outputs, TFBaseModelOutput): - raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.") - - return { - "pixel_values": None, # encoder_outputs is defined. pixel_values not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete + "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, } + return input_dict def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @@ -773,9 +750,4 @@ def resize_token_embeddings(self, *args, **kwargs): def _reorder_cache(self, past, beam_idx): # apply decoder cache reordering here - if len(past) == 1: - return past - - encoder_outputs, past_key_values = past - - return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx)) + return self.decoder._reorder_cache(past, beam_idx) diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index ea0f6b6baf84..96aa88bb2df2 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -1246,17 +1246,17 @@ def get_prefix_bias_name(self): warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_loss.name - def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs): + def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs): # Add dummy token at the end (no attention on this one) + effective_batch_size = inputs.shape[0] + dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype) + # At every pass, the attention values for the new token and the two last generated tokens # are computed, the rest is reloaded from the `past` cache. 
A purely auto-regressive model would have # offset = 1; offset = 2 seems to have slightly better computation. offset = 2 - effective_batch_size = inputs.shape[0] - dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype) - if past: inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1) else: @@ -1277,7 +1277,7 @@ def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs): "input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, - "use_mems": kwargs.get("use_mems"), + "use_mems": use_mems, } # if past is defined in model kwargs then use it for faster decoding diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index bb7adc9d0540..25afc22d6c03 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -1777,7 +1777,7 @@ def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAn {% else %} import random -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorflow as tf @@ -2736,9 +2736,6 @@ def call( if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - if inputs["use_cache"]: - present_key_values = (inputs["encoder_hidden_states"], present_key_values) - if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: @@ -3186,43 +3183,23 @@ def serving_output(self, output): def prepare_inputs_for_generation( self, decoder_input_ids, - past, - attention_mask, + past=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, - use_cache=False, + use_cache=None, + encoder_outputs=None, **kwargs - ) -> Dict: - assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" - if len(past) == 1: - assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - past_key_values = None - else: - assert ( - len(past) == 2 - ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." - encoder_outputs, past_key_values = past - if isinstance(encoder_outputs, tuple): - assert isinstance( - encoder_outputs[0], tf.Tensor - ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) - elif isinstance(encoder_outputs, tf.Tensor): - encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) - assert ( - past_key_values - ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + ): + # cut decoder_input_ids if past is used + if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] - assert isinstance( - encoder_outputs, TFBaseModelOutput - ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -3233,17 +3210,10 @@ def prepare_inputs_for_generation( @staticmethod def _reorder_cache(past, beam_idx): - if len(past) == 1: - return past - - past_key_values = past[1] - reordered_past = () - for layer_past_key_values in past_key_values: - reordered_past += ( - tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:], - ) - return (past[0], reordered_past) + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),) + return reordered_past def hf_compute_loss(self, labels, logits): """CrossEntropyLoss that ignores pad tokens""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py index d2875f232c73..16b31500dd6c 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -802,7 +802,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/bart/test_modeling_tf_bart.py b/tests/bart/test_modeling_tf_bart.py index c231e1418866..417f6edcafe9 100644 --- a/tests/bart/test_modeling_tf_bart.py +++ b/tests/bart/test_modeling_tf_bart.py @@ -116,7 +116,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/blenderbot/test_modeling_tf_blenderbot.py b/tests/blenderbot/test_modeling_tf_blenderbot.py index e8aebc446232..3d0e8fc4365b 100644 --- a/tests/blenderbot/test_modeling_tf_blenderbot.py +++ b/tests/blenderbot/test_modeling_tf_blenderbot.py @@ -114,7 +114,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py b/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py index cb74c799cb12..6a3eeb826d2f 100644 --- a/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py +++ b/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py @@ -114,7 +114,6 @@ def 
check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/led/test_modeling_tf_led.py b/tests/led/test_modeling_tf_led.py index 6870c2b08cbe..cb75ddf8c3eb 100644 --- a/tests/led/test_modeling_tf_led.py +++ b/tests/led/test_modeling_tf_led.py @@ -133,7 +133,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/marian/test_modeling_tf_marian.py b/tests/marian/test_modeling_tf_marian.py index fb7b8629f907..23bd9be1fc2d 100644 --- a/tests/marian/test_modeling_tf_marian.py +++ b/tests/marian/test_modeling_tf_marian.py @@ -116,7 +116,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/pegasus/test_modeling_tf_pegasus.py b/tests/pegasus/test_modeling_tf_pegasus.py index b9e7c45b6db6..ca0d52526740 100644 --- a/tests/pegasus/test_modeling_tf_pegasus.py +++ b/tests/pegasus/test_modeling_tf_pegasus.py @@ -114,7 +114,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() - past_key_values = past_key_values[1] # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) diff --git a/tests/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/speech_to_text/test_modeling_tf_speech_to_text.py index b9036ae3a19c..3aad2ba9b0bb 100644 --- a/tests/speech_to_text/test_modeling_tf_speech_to_text.py +++ b/tests/speech_to_text/test_modeling_tf_speech_to_text.py @@ -179,7 +179,7 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): # first forward pass outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) - _, (_, past_key_values) = outputs.to_tuple() + _, past_key_values = outputs.to_tuple() # create hypothetical multiple next token and extent to next_input_ids next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2) diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/t5/test_modeling_tf_t5.py index 49d020d17bc3..5abf66f4c23b 100644 --- a/tests/t5/test_modeling_tf_t5.py +++ b/tests/t5/test_modeling_tf_t5.py @@ -98,13 +98,10 @@ def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels) encoder_output = result.encoder_last_hidden_state self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertEqual(len(decoder_past), 2) 
-        # decoder_past[0] should correspond to encoder output
-        self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
         # There should be `num_layers` key value embeddings stored in decoder_past[1]
-        self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
+        self.parent.assertEqual(len(decoder_past), config.num_layers)
         # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
-        self.parent.assertEqual(len(decoder_past[1][0]), 4)
+        self.parent.assertEqual(len(decoder_past[0]), 4)

     def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
         model = TFT5ForConditionalGeneration(config=config)
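The per-model `_reorder_cache` rewrites in the hunks above all reduce to the same per-layer gather once `past` carries only decoder key/value states (it is the pattern the RoBERTa hunk copies from `TFBertLMHeadModel._reorder_cache`). A minimal sketch of that shared pattern, written as a hypothetical standalone helper for illustration, assuming `past` is a tuple with one entry per layer and each entry is a tuple of `tf.Tensor` states batched on axis 0:

import tensorflow as tf


def reorder_cache(past, beam_idx):
    # the empty-tuple accumulator must be initialized before the per-layer loop
    reordered_past = ()
    for layer_past in past:
        # pick the surviving beams along the batch axis for every cached state of this layer
        reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
    return reordered_past

Beam search invokes this hook after each step so that the cached key/value tensors follow whichever beams were kept.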
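Memory-based models go through the same hook with a different cache layout, which is why the Transfo-XL hunk gathers along axis 1 instead of axis 0. A minimal sketch, assuming the time-major memory layout of TF Transfo-XL, where each entry of `mems` has shape (memory_length, batch_size, hidden_size); the standalone function name is illustrative:

from typing import List

import tensorflow as tf


def reorder_mems(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
    # the beam/batch dimension of time-major memories is axis 1
    return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]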
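The encoder-decoder `prepare_inputs_for_generation` hooks converge in the same way once `encoder_outputs` arrives as its own keyword argument rather than being unpacked from `past`. A minimal sketch of that contract, with the argument list trimmed for clarity (real models also forward extras such as `head_mask` and `decoder_head_mask`):

def prepare_inputs_for_generation(
    self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
    # with cached key/values, only the most recently generated token needs to be fed to the decoder
    if past is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": None,  # still passed so that Keras' layer.__call__ signature is satisfied
        "encoder_outputs": encoder_outputs,  # computed once before decoding, reused unchanged at every step
        "past_key_values": past,  # decoder key/values only; no encoder outputs smuggled inside
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
    }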