10 changes: 7 additions & 3 deletions src/transformers/generation/tf_utils.py
@@ -3066,16 +3066,20 @@ def beam_search_cond_fn(
not_max_length_yet = cur_len < max_length

# 2. can the new beams still improve?
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
best_running_score = running_scores[:, :1] / tf.cast(cur_len, dtype=running_scores.dtype) ** length_penalty
@ydshieh (Collaborator Author) commented on Dec 21, 2022:
In the current main branch, max_length is used instead of cur_len. However, in our PyTorch generation's BeamHypotheses it is cur_len; see

cur_score = best_sum_logprobs / cur_len**self.length_penalty

When running the code snippet from the reported TFMarian issue (#18149), max_length is a constant 512, while the PyTorch generation code runs with cur_len ranging from 1 (or 2) to 5.

(However, this is not the root cause of the issue in #18149.)
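To make the difference concrete, here is a minimal numeric sketch (plain Python, with made-up values) of how the choice of denominator changes the "best score still achievable" bound used in the stopping condition:

```python
# Illustrative sketch: how the length-penalty denominator affects the
# best-achievable-score bound. All values below are made up.
best_sum_logprobs = -4.0   # log-prob sum of the best running beam
length_penalty = 1.0
cur_len = 5                # current decoding step
max_length = 512           # generation limit (constant for the whole run)

# Normalizing by max_length shrinks the bound's magnitude drastically,
# so "improvement still possible" almost never becomes False.
score_by_max_len = best_sum_logprobs / max_length**length_penalty

# Normalizing by cur_len (the PyTorch behavior) gives a realistic bound.
score_by_cur_len = best_sum_logprobs / cur_len**length_penalty

print(score_by_max_len)  # -0.0078125
print(score_by_cur_len)  # -0.8
```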

worst_finished_score = tf.where(
is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)

# 3. is there still a beam that has not finished?
still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
# still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
Collaborator Author:

This line should be removed before merge

still_open_beam = ~(tf.math.reduce_all(is_sent_finished))

return not_max_length_yet & (still_open_beam | improvement_still_possible)
_early_stopping = tf.constant(early_stopping > 0, dtype=tf.bool)

# return not_max_length_yet & (still_open_beam | improvement_still_possible)
Collaborator Author:

should be removed before merge

return not_max_length_yet & (still_open_beam | (~_early_stopping & improvement_still_possible))
Comment on lines +3076 to +3082
@ydshieh (Collaborator Author) commented on Dec 21, 2022:
The method beam_search_cond_fn corresponds to BeamHypotheses.is_done in our PyTorch generation code (although the meaning is reversed: generation done vs. not done).

The above suggests:

  • The main issue in TFMarian's super slow generation comes from the condition around early_stopping.
  • With the changes in this PR, it generates quickly, just like the PyTorch Marian.

I ran the slow tests for bert, gpt2, bart, and t5. One test needs to be fixed: tests/models/bart/test_modeling_tf_bart.py::TFBartModelTest::test_xla_generate_slow
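The revised stopping condition can be sketched as a scalar Python model (a simplification of the TF tensor code in the hunk above, not the actual implementation; the tensor reductions are collapsed into booleans):

```python
def beam_search_should_continue(cur_len, max_length, all_beams_finished,
                                early_stopping, improvement_still_possible):
    """Simplified scalar model of the revised beam_search_cond_fn.

    Continue while the length budget remains AND either some beam is
    still open, or early stopping is off and a running beam could still
    beat the worst finished one.
    """
    not_max_length_yet = cur_len < max_length
    still_open_beam = not all_beams_finished
    return not_max_length_yet and (
        still_open_beam or (not early_stopping and improvement_still_possible)
    )

# With early_stopping=True, finishing all beams ends the search at once.
print(beam_search_should_continue(3, 512, True, True, True))    # False
# With early_stopping=False, the search keeps going while a running
# beam might still improve on the worst finished beam.
print(beam_search_should_continue(3, 512, True, False, True))   # True
```

This mirrors the new return expression: `not_max_length_yet & (still_open_beam | (~_early_stopping & improvement_still_possible))`.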

@ydshieh (Collaborator Author) commented on Dec 21, 2022:
However, one thing I don't understand very well: the check len(self) < self.num_beams in BeamHypotheses.is_done,

if len(self) < self.num_beams:

vs.

tf.math.reduce_all(is_sent_finished) and/or not_max_length_yet in beam_search_cond_fn.

These don't seem to be 100% equivalent conditions. (But I didn't really go into the details of this part.)
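For comparison, the PyTorch-side check discussed here can be paraphrased as follows (a simplified sketch built around the snippet quoted above; the class layout and attribute names are assumptions for illustration, not the exact transformers source):

```python
class BeamHypotheses:
    """Paraphrased sketch of the PyTorch-side bookkeeping that
    beam_search_cond_fn is compared against. Not the exact source."""

    def __init__(self, num_beams, length_penalty, early_stopping):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.beams = []          # finished hypotheses
        self.worst_score = 1e9   # worst score among finished hypotheses

    def __len__(self):
        return len(self.beams)

    def is_done(self, best_sum_logprobs, cur_len):
        # Too few finished hypotheses: definitely not done yet.
        if len(self) < self.num_beams:
            return False
        # early_stopping: stop as soon as num_beams hypotheses finished.
        if self.early_stopping:
            return True
        # Otherwise: done only when no running beam can still beat the
        # worst finished one. Note cur_len (not max_length) here.
        cur_score = best_sum_logprobs / cur_len**self.length_penalty
        return self.worst_score >= cur_score


hyp = BeamHypotheses(num_beams=1, length_penalty=1.0, early_stopping=False)
hyp.beams.append(("dummy hypothesis", -0.5))
hyp.worst_score = -0.5
# A running beam with sum logprob -4.0 at cur_len=5 scores -0.8, which
# cannot beat the worst finished score -0.5, so the search is done.
print(hyp.is_done(best_sum_logprobs=-4.0, cur_len=5))  # True
```

The `len(self) < self.num_beams` guard counts finished hypotheses per input, whereas `tf.math.reduce_all(is_sent_finished)` checks whether every beam slot is finished, which is why the two conditions do not line up exactly.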


def beam_search_body_fn(
cur_len,