6 changes: 6 additions & 0 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -740,6 +740,12 @@ def call(
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
 
+        # The received `past_key_values` is a tuple of 2 elements.
+        # The 1st element is `encoder_hidden_states`; the 2nd element is a tuple of `n_layers` elements,
+        # each of which is a tuple of 4 tensors of shape (batch_size, n_heads, seq_len - 1, embed_size_per_head).
+        if type(past_key_values) == tuple and len(past_key_values) == 2 and type(past_key_values[1]) == tuple:
Contributor:

This shouldn't be the case. past_key_values should only be the 2nd element you describe above. @gante - did we miss this model in the past/encoder_outputs refactor?

Contributor:

TFBert and TFEncoderDecoder were both updated (PR). The test in question also doesn't pass the num_beams or sample arguments to generate, meaning it should be going through the updated greedy_search, which uses the updated past format 🤔

@ydshieh, can I ask you to confirm that generate() calls the new _generate() internally (this one)? If not, it means it is going through the old code, which may have some past-related issues (I made some ad hoc changes to comply with the new past format, but they may be incomplete).
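
One lightweight way to verify this is to trace the call, sketched below under assumptions: trace_method is not a transformers helper, and _generate is simply the method name from the comment above.

```python
# Hypothetical tracing helper: wrap a method on an object so a message is
# printed when (and if) that method is actually reached during a call chain.
def trace_method(obj, name):
    original = getattr(obj, name)

    def wrapper(*args, **kwargs):
        print(f"{name} was called")
        return original(*args, **kwargs)

    setattr(obj, name, wrapper)
```

With a model in hand, trace_method(model, "_generate") followed by a normal model.generate(...) call would print a line only if the new path is reached.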

Collaborator Author:

I can check the logic flow and will report back.

Collaborator Author:

It is probably TFEncoderDecoderModel.prepare_inputs_for_generation that should be corrected, rather than TFBertModel.

Since I was the one who worked on TFEncoderDecoderModel, I can fix it instead. @gante, WDYT? If so, do you have any hints/tips for me about the format changes you made?
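
A minimal sketch of the direction proposed here, assuming the unwrapping moves into the encoder-decoder's input preparation; the signature and the returned keys below are illustrative, not the actual transformers implementation:

```python
# Hypothetical sketch of a TFEncoderDecoderModel.prepare_inputs_for_generation
# that unwraps the legacy (encoder_hidden_states, cache) 2-tuple itself, so the
# inner decoder only ever receives the per-layer key/value cache.
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
    encoder_outputs = kwargs.get("encoder_outputs")
    if isinstance(past, tuple) and len(past) == 2 and isinstance(past[1], tuple):
        # Legacy format: 1st element is encoder_hidden_states, 2nd is the cache.
        encoder_outputs, past = past
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "encoder_outputs": encoder_outputs,
        "use_cache": use_cache,
    }
```

Splitting the tuple here would keep the inner TFBertModel agnostic of the wrapper format, which is the point being made above.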

Contributor:

So the problem should be related to this line, which probably means I missed some updates 🙈

I can look into it; it is almost surely related to a problem in my past PR.

Collaborator Author:

Ok, thank you!

+            past_key_values = past_key_values[1]
+
         if not self.config.is_decoder:
             use_cache = False
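
For reference, a self-contained sketch of the two cache layouts the added check distinguishes; the sizes are placeholders and the zero tensors stand in for real keys and values:

```python
import tensorflow as tf

# Illustrative sizes only.
batch_size, n_heads, seq_len, head_dim, n_layers = 2, 12, 5, 64, 4

# New format: a tuple of `n_layers` elements, each a tuple of 4 tensors of
# shape (batch_size, n_heads, seq_len - 1, embed_size_per_head).
layer_cache = tuple(
    tf.zeros((batch_size, n_heads, seq_len - 1, head_dim)) for _ in range(4)
)
new_format = tuple(layer_cache for _ in range(n_layers))

# Legacy format: a 2-tuple whose 1st element is `encoder_hidden_states` and
# whose 2nd element is the per-layer cache above.
encoder_hidden_states = tf.zeros((batch_size, seq_len, n_heads * head_dim))
legacy_format = (encoder_hidden_states, new_format)

# The check added in the diff unwraps the legacy wrapper and leaves the
# new format untouched.
past_key_values = legacy_format
if type(past_key_values) == tuple and len(past_key_values) == 2 and type(past_key_values[1]) == tuple:
    past_key_values = past_key_values[1]

assert past_key_values is new_format
```

Note that the len == 2 heuristic cannot distinguish a genuine 2-layer new-format cache, which would also match it; that fragility is one reason to handle the wrapper upstream rather than inside TFBertModel.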
