94 changes: 25 additions & 69 deletions src/transformers/generation_tf_utils.py
@@ -867,9 +867,8 @@ def _generate_beam_search(

beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))

# cache compute states
past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# variable to cache compute states
past = None

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
@@ -886,6 +885,13 @@
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
# the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
# variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
# `prepare_inputs_for_generation`
if encoder_hidden_states is not None:
@patrickvonplaten (Contributor) commented on Mar 5, 2022:

(nit) Why not wrap it into a TFEncoderOutputs class here?

Contributor Author replied:

Great question! I tried that, and it would be the most sensible change IMO (as the updated generate gets the encoder outputs with return_dict=True). However, a TFEncoderOutputs would make the T5 tests fail. At that point, I had two options: update TF T5 or write this. Since this PR is mostly about updating the past variable, I thought it would be the path of least resistance.

Happy to change T5 instead :)

encoder_outputs = (*encoder_outputs, encoder_hidden_states)
if encoder_attentions is not None:
encoder_outputs = (*encoder_outputs, encoder_attentions)

# done sentences
done = [False for _ in range(batch_size)]
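
For readers outside the diff context, here is a minimal standalone sketch (not code from this PR) contrasting the two approaches discussed in the thread above: extending the plain `encoder_outputs` tuple, as the diff does, versus wrapping everything in a small output class, as the reviewer suggests. The `DummyEncoderOutputs` class and the toy tensors are illustrative assumptions.

```python
import tensorflow as tf
from dataclasses import dataclass
from typing import Optional, Tuple

# toy encoder results: (batch, seq_len, hidden) plus per-layer extras
last_hidden_state = tf.zeros((2, 5, 8))
encoder_hidden_states = (tf.zeros((2, 5, 8)),)
encoder_attentions = (tf.zeros((2, 4, 5, 5)),)

# Option 1 (what the diff does): append the optional pieces to a plain tuple,
# so downstream code can unpack them positionally.
encoder_outputs = (last_hidden_state,)
if encoder_hidden_states is not None:
    encoder_outputs = (*encoder_outputs, encoder_hidden_states)
if encoder_attentions is not None:
    encoder_outputs = (*encoder_outputs, encoder_attentions)

# Option 2 (the reviewer's nit): wrap everything in a small output class,
# closer to what `return_dict=True` already produces.
@dataclass
class DummyEncoderOutputs:
    last_hidden_state: tf.Tensor
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None

wrapped = DummyEncoderOutputs(last_hidden_state, encoder_hidden_states, encoder_attentions)
```
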
@@ -896,6 +902,7 @@
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
encoder_outputs=encoder_outputs,
**kwargs,
)
outputs = self(
@@ -1486,14 +1493,10 @@ def _generate(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)

# 4. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
input_ids, return_dict_in_generate, model_kwargs
)

# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
Contributor comment:

(nit) you could maybe put this under the `# 4. Prepare ...` comment and change the comment to `prepare model inputs which will be used for ...`

# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
@@ -1531,10 +1534,6 @@ def _generate(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)

# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None

# 8. run greedy search
return self.greedy_search(
input_ids,
@@ -1559,10 +1558,6 @@
**model_kwargs,
)

# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None

# 10. run sample
return self.sample(
input_ids,
@@ -1589,12 +1584,7 @@ def _prepare_attention_mask_for_generation(
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)

def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
) -> Dict[str, Any]:
# TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
# is cleaned

def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
# get encoder and store encoder outputs
encoder = self.get_encoder()

@@ -1612,17 +1602,8 @@ def _prepare_encoder_decoder_kwargs_for_generation(
encoder_kwargs.pop("attention_mask")

encoder_outputs = encoder(input_ids, **encoder_kwargs)

model_kwargs["encoder_outputs"] = encoder_outputs

# TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
# `encoder_hidden_states` have to be separated from encoder_outputs and passed
# under other names because of `encoder_outputs`, `past` hack. Need to clean-up
# all encoder-decoder prepare_inputs_for_generation method to clean this
if return_dict_in_generate:
model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)

return model_kwargs
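
As a rough, hedged illustration of the simplified helper above: run the encoder once and stash its full output under `model_kwargs["encoder_outputs"]`, with no `return_dict_in_generate` special-casing. The function name and the kwarg filter below are assumptions for the sketch, not the library's exact implementation.

```python
import tensorflow as tf

def prepare_encoder_decoder_kwargs(encoder, input_ids: tf.Tensor, model_kwargs: dict) -> dict:
    # forward only the kwargs an encoder typically understands (simplified filter)
    encoder_kwargs = {
        k: v
        for k, v in model_kwargs.items()
        if k in ("attention_mask", "output_attentions", "output_hidden_states", "return_dict")
    }
    # one encoder pass; the whole output object travels inside model_kwargs from here on
    model_kwargs["encoder_outputs"] = encoder(input_ids, **encoder_kwargs)
    return model_kwargs
```
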

def _prepare_decoder_input_ids_for_generation(
@@ -1712,27 +1693,17 @@ def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id

return inputs

@staticmethod
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if self._use_cache(outputs, model_kwargs["use_cache"]):
# TODO(Patrick): `past`/`encoder_outputs` hack. This should be
# removed when cleaning up the encoder-decoder models
# if model has past, then set the past variable to speed up decoding
# make this method static then as well
model_kwargs["past"] = outputs[1]
elif "past_key_values" in outputs:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
elif "past" in model_kwargs:
# TODO(Patrick) `past`/`encoder_outputs` hack.
# removed when cleaning up the encoder-decoder models.
# The line should not be necessary.
pass
else:
model_kwargs["past"] = None

@@ -1907,26 +1878,18 @@ def greedy_search(
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. This is a bad design and needs
# to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]

while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -2129,25 +2092,18 @@ def sample(
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. This is a bad design and needs to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]

while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

54 changes: 13 additions & 41 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -16,7 +16,7 @@


import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import tensorflow as tf

@@ -1012,9 +1012,6 @@ def call(
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)

if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)

if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@@ -1449,43 +1446,23 @@ def serving_output(self, output):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -1499,15 +1476,10 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):

@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past

past_key_values = past[1]

reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past
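
Conceptually, the rewritten `_reorder_cache` gathers the self-attention key/value states (the first two entries per layer) along the flattened batch*beam axis and leaves any cached cross-attention states untouched. A hedged standalone sketch with toy shapes (values and layer layout are illustrative assumptions):

```python
import tensorflow as tf

def reorder_cache(past, beam_idx):
    reordered = ()
    for layer_past in past:
        # first two entries: self-attention key/value -> re-gather per selected beam
        # remaining entries (e.g. cross-attention cache) stay as they are
        reordered += (
            tuple(tf.gather(state, beam_idx, axis=0) for state in layer_past[:2]) + layer_past[2:],
        )
    return reordered

# toy usage: one layer of (self_k, self_v, cross_k, cross_v) with batch*beams == 4
layer = tuple(tf.reshape(tf.range(8, dtype=tf.float32), (4, 2)) for _ in range(4))
reordered = reorder_cache((layer,), beam_idx=tf.constant([2, 2, 0, 1]))
```
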
25 changes: 16 additions & 9 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -1443,17 +1443,17 @@ def get_prefix_bias_name(self) -> str:
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)

# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
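
For completeness, a toy sketch of the input trimming that `prepare_inputs_for_generation` performs once a cache is available: only the most recent token is fed to the model, since earlier positions are already covered by `past`. The tensors and the `past` placeholder below are illustrative assumptions.

```python
import tensorflow as tf

input_ids = tf.constant([[101, 7592, 2088], [101, 2129, 2024]])  # (batch, seq_len), toy ids
past = object()  # stand-in for real cached key/value states

# when the model is used as a decoder, an all-ones attention mask can be created on the fly
attention_mask = tf.ones(input_ids.shape)

# cut input_ids down to the last token when a cache is present
if past is not None:
    input_ids = input_ids[:, -1:]

model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
```
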
Expand Down Expand Up @@ -1575,6 +1575,13 @@ def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausa
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)

@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past


@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top.""",