-
Notifications
You must be signed in to change notification settings - Fork 31.9k
Fix TF generation (especially for TFMarian)
#20853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3066,16 +3066,20 @@ def beam_search_cond_fn( | |||||||||
| not_max_length_yet = cur_len < max_length | ||||||||||
|
|
||||||||||
| # 2. can the new beams still improve? | ||||||||||
| best_running_score = running_scores[:, :1] / (max_length**length_penalty) | ||||||||||
| best_running_score = running_scores[:, :1] / tf.cast(cur_len, dtype=running_scores.dtype) ** length_penalty | ||||||||||
| worst_finished_score = tf.where( | ||||||||||
| is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9 | ||||||||||
| ) | ||||||||||
| improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score) | ||||||||||
|
|
||||||||||
| # 3. is there still a beam that has not finished? | ||||||||||
| still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping) | ||||||||||
| # still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping) | ||||||||||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line should be removed before merge |
||||||||||
| still_open_beam = ~(tf.math.reduce_all(is_sent_finished)) | ||||||||||
|
|
||||||||||
| return not_max_length_yet & (still_open_beam | improvement_still_possible) | ||||||||||
| _early_stopping = tf.constant(early_stopping > 0, dtype=tf.bool) | ||||||||||
|
|
||||||||||
| # return not_max_length_yet & (still_open_beam | improvement_still_possible) | ||||||||||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be removed before merge |
||||||||||
| return not_max_length_yet & (still_open_beam | (~_early_stopping & improvement_still_possible)) | ||||||||||
|
Comment on lines
+3076
to
+3082
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The method
The above suggests:
I run the slow tests for
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. However, one thing I don't understand very well is: this part
v.s.
It doesn't seem 100% equivalent conditions. (But I didn't really go into the details around this part) |
||||||||||
|
|
||||||||||
| def beam_search_body_fn( | ||||||||||
| cur_len, | ||||||||||
|
|
||||||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In current
mainbranch,max_lengthis used instead ofcur_len. However, in our PyTorch generation'sBeamHypotheses, it iscur_len, seetransformers/src/transformers/generation/beam_search.py
Line 895 in 3be028b
When running the code snippet in the reported TFMarian issue (#18149), we get
max_lengthbeing a constant of512, but the PyTorch generation code runs withcur_lenwhich is from1(or2) to5.(However, this is not the root cause of the issue in #18149)