TF generate refactor - past without encoder outputs #15944
Merged

Changes from all commits (10 commits):
- 6052c8f (gante): Remove packed past from generation_tf_utils; fix tf speech_to_text
- 6d5c50b (gante): Fix bart
- 06a10c2 (gante): refactored and tested up to ctrl
- 2ef7df0 (gante): up to t5 (excluding)
- 71f28d7 (gante): all models updated, a few tests failing
- 117d0ee (gante): final tweaks to open the PR
- 84867f0 (gante): Merge branch 'master' into destroy_past
- 2526021 (gante): PR comments
- 14117ac (gante): update template tests accordingly
- 17230aa (gante): correct template imports
Diff for `generation_tf_utils.py`:
```diff
@@ -867,9 +867,8 @@ def _generate_beam_search(
         beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))

-        # cache compute states
-        past = encoder_outputs
-        # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
+        # variable to cache compute states
+        past = None

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
```
```diff
@@ -886,6 +885,13 @@ def _generate_beam_search(
             if (return_dict_in_generate and kwargs["encoder_hidden_states"])
             else None
         )
+        # the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
+        # variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
+        # `prepare_inputs_for_generation`
+        if encoder_hidden_states is not None:
+            encoder_outputs = (*encoder_outputs, encoder_hidden_states)
+        if encoder_attentions is not None:
+            encoder_outputs = (*encoder_outputs, encoder_attentions)

         # done sentences
         done = [False for _ in range(batch_size)]
```
```diff
@@ -896,6 +902,7 @@ def _generate_beam_search(
                 past=past,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                encoder_outputs=encoder_outputs,
                 **kwargs,
             )
             outputs = self(
```
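Taken together, the three beam-search hunks stop packing the encoder results into `past`: `past` now starts as `None`, and the encoder outputs travel under their own `encoder_outputs` keyword. A minimal sketch of what that implies on the model side, assuming a hypothetical encoder-decoder model's `prepare_inputs_for_generation` (the body is illustrative, not the exact code of any shipped model):

```python
def prepare_inputs_for_generation(self, decoder_input_ids, past=None, encoder_outputs=None, **kwargs):
    # `past` now only ever holds decoder cache states (or None on the first step);
    # the encoder results arrive separately under `encoder_outputs`
    if past is not None:
        # with a cache, only the last generated token needs to be fed to the decoder
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "decoder_input_ids": decoder_input_ids,
        "encoder_outputs": encoder_outputs,  # no longer smuggled inside `past`
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
    }
```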
```diff
@@ -1486,14 +1493,10 @@ def _generate(
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)

-        # 4. Prepare model inputs which will be used for auto-regressive generation
-        if self.config.is_encoder_decoder:
-            # if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-                input_ids, return_dict_in_generate, model_kwargs
-            )
+        # 4. Prepare `input_ids` which will be used for auto-regressive generation
+        if self.config.is_encoder_decoder:
+            # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

         # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
         input_ids = self._prepare_decoder_input_ids_for_generation(
             batch_size,
```

Contributor review comment on this hunk: (nit) you could maybe put the … under the …
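As the comment in this hunk notes, for encoder-decoder models the decoder-side `input_ids` are seeded from `decoder_start_token_id` rather than from the encoder's input sequence. A tiny self-contained illustration of that seeding (the token id value is made up; each model config defines its own):

```python
import tensorflow as tf

batch_size = 2
decoder_start_token_id = 0  # made-up value for illustration

# every decoder sequence starts from a single decoder start token
input_ids = tf.fill((batch_size, 1), decoder_start_token_id)
print(input_ids)  # tf.Tensor([[0], [0]], shape=(2, 1), dtype=int32)
```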
```diff
@@ -1531,10 +1534,6 @@ def _generate(
                     f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                 )

-            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
-            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
-            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
             # 8. run greedy search
             return self.greedy_search(
                 input_ids,
```
```diff
@@ -1559,10 +1558,6 @@ def _generate(
                 **model_kwargs,
             )

-            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
-            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
-            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
             # 10. run sample
             return self.sample(
                 input_ids,
```
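The two deleted blocks above were the core of the hack: before running greedy search or sampling, `past` was pre-seeded with a slice of the encoder outputs. A small runnable illustration of the behavioural difference, using plain Python stand-ins instead of real tensors:

```python
# stand-ins for the real encoder output tensors
encoder_outputs = ("last_hidden_state", "hidden_states", "attentions")
model_kwargs = {"encoder_outputs": encoder_outputs}

# old behaviour (deleted): `past` started life as a 1-tuple wrapping encoder output
old_past = model_kwargs.get("encoder_outputs")[:1]
assert old_past == ("last_hidden_state",)

# new behaviour: `past` is simply not set here; it is first populated from the
# model outputs inside `_update_model_kwargs_for_generation`
assert "past" not in model_kwargs
```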
```diff
@@ -1589,12 +1584,7 @@ def _prepare_attention_mask_for_generation(
         else:
             return tf.ones(input_ids.shape[:2], dtype=tf.int32)

-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
-    ) -> Dict[str, Any]:
-        # TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
-        # is cleaned
+    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
```
```diff
@@ -1612,17 +1602,8 @@ def _prepare_encoder_decoder_kwargs_for_generation(
             encoder_kwargs.pop("attention_mask")

         encoder_outputs = encoder(input_ids, **encoder_kwargs)
         model_kwargs["encoder_outputs"] = encoder_outputs

-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
-        # `encoder_hidden_states` have to be seperated from encoder_outputs and passed
-        # under other names because of `encoder_outputs`, `past` hack. Need to clean-up
-        # all encoder-decoder prepare_inputs_for_generation method to clean this
-        if return_dict_in_generate:
-            model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
-            model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)
-
         return model_kwargs

     def _prepare_decoder_input_ids_for_generation(
```
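After this hunk, the helper has a single responsibility: run the encoder once and stash the whole output object in `model_kwargs`. A condensed sketch of the resulting method (reassembled from the hunks above; the kwarg filtering around `attention_mask` is simplified, not the exact code):

```python
from typing import Any, Dict

import tensorflow as tf


def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
    # get encoder and store encoder outputs
    encoder = self.get_encoder()

    # simplified: the real method also filters kwargs such as attention_mask
    encoder_kwargs = {k: v for k, v in model_kwargs.items() if not k.startswith("decoder_")}
    encoder_outputs = encoder(input_ids, **encoder_kwargs)

    # the full output object is stored; attentions / hidden states are no longer
    # split out under separate `model_kwargs` keys
    model_kwargs["encoder_outputs"] = encoder_outputs
    return model_kwargs
```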
```diff
@@ -1712,27 +1693,17 @@ def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id
         return inputs

+    @staticmethod
     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
     ) -> Dict[str, Any]:
         # update past
-        if self._use_cache(outputs, model_kwargs["use_cache"]):
-            # TODO(Patrick): `past`/`encoder_outputs` hack. This should be
-            # removed when cleaning up the encoder-decoder models
-            # if model has past, then set the past variable to speed up decoding
-            # make this method static then as well
-            model_kwargs["past"] = outputs[1]
-        elif "past_key_values" in outputs:
+        if "past_key_values" in outputs:
             model_kwargs["past"] = outputs.past_key_values
         elif "mems" in outputs:
             model_kwargs["past"] = outputs.mems
         elif "past_buckets_states" in outputs:
             model_kwargs["past"] = outputs.past_buckets_states
-        elif "past" in model_kwargs:
-            # TODO(Patrick) `past`/`encoder_outputs` hack.
-            # removed when cleaning up the encoder-decoder models.
-            # The line should not be necessary.
-            pass
         else:
             model_kwargs["past"] = None
```
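Reassembled from the `+` and context lines above, the method now reads as a static, uniform cache update: every cache flavour (`past_key_values`, `mems`, `past_buckets_states`) lands in the single `past` key, with no special case for encoder outputs. The rest of the real method (attention-mask bookkeeping) is elided here:

```python
from typing import Any, Dict

from transformers.file_utils import ModelOutput


@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    # update past: all cache flavours map onto the single `past` key
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None
    return model_kwargs  # attention-mask updates elided
```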
```diff
@@ -1907,26 +1878,18 @@ def greedy_search(
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
-        # to be wrapped into `past` variable. Tis is a bad design and needs
-        # to be updated.
-        # Remove the following lines when updating all encoder-decoder models
-        encoder_outputs = model_kwargs.pop("encoder_outputs", None)
-
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
-            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )

         # keep track of which sequences are already finished
         unfinished_sequences = tf.ones_like(input_ids[:, 0])
         cur_len = input_ids.shape[-1]

         while cur_len < max_length:
-            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
-            # in all models
-            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
-
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
```
```diff
@@ -2129,25 +2092,18 @@ def sample(
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
-        # to be wrapped into `past` variable. This is a bad design and needs to be updated.
-        # Remove the following lines when updating all encoder-decoder models
-        encoder_outputs = model_kwargs.pop("encoder_outputs", None)
-
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
-            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )

         # keep track of which sequences are already finished
         unfinished_sequences = tf.ones_like(input_ids[:, 0])
         cur_len = input_ids.shape[-1]

         while cur_len < max_length:
-            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
-            # in all models
-            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
-
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
```
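With these two hunks, both `greedy_search` and `sample` read encoder attentions and hidden states directly from `model_kwargs["encoder_outputs"]` instead of popping the object out first. A hedged end-to-end sketch of the user-visible behaviour (the checkpoint name and exact output fields are assumptions, not taken from this diff):

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# any TF encoder-decoder checkpoint should behave the same way; t5-small is just an example
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: hello", return_tensors="tf")
out = model.generate(
    inputs.input_ids,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

# the encoder-side tensors now come straight from `model_kwargs["encoder_outputs"]`
print(out.encoder_attentions is not None)
print(out.encoder_hidden_states is not None)
```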
Review comment:

(nit) Why not wrap it into a `TFEncoderOutputs` class here?

Reply (gante):

Great question! I tried that; it would be the most sensible change IMO (as the updated generate gets the encoder outputs with `return_dict=True`). However, a `TFEncoderOutputs` class would make the T5 tests fail. At that point, I had two options: update TF T5 or write this. Since this PR is mostly about updating the `past` variable, I thought this was the path of least resistance. Happy to change T5 instead :)
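For context, here is a sketch of what the reviewer's suggestion could look like. `TFEncoderOutputs` is hypothetical and does not ship in the library; per the reply above, this PR deliberately keeps the plain tuple instead:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf
from transformers.file_utils import ModelOutput


@dataclass
class TFEncoderOutputs(ModelOutput):
    """Hypothetical container for the three encoder artefacts the tuple packs together."""

    last_hidden_state: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
```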