🚨🚨 Generate: correct beam search best possible score computation and handling #20901
Conversation
@ydshieh regarding the original issue (#18149) -- the problem was not TF with too many beam search iterations, but rather PT with not enough 😅 After this fix, in the example you shared (which I paste below, for reference), both PT and TF run >300 steps to conclude that "bonjour" is the answer. Please note that TF includes the padding in its output (as opposed to PT, which doesn't) because its output tensors are pre-padded and sliced based on the number of iterations, whereas in PT they are growing tensors that can be stored as candidate outputs without padding.
Python example:

```python
from transformers import MarianMTModel, MarianTokenizer, TFMarianMTModel
import tensorflow as tf

model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
tokenizer = MarianTokenizer.from_pretrained(model_name)
text_in = ['>>fr<< hello']

# PT generates a few tokens then stops early -> very fast
model = MarianMTModel.from_pretrained(model_name)
batch = tokenizer(text_in, return_tensors='pt', padding=True)
translated = model.generate(**batch)
o = tokenizer.batch_decode(translated, skip_special_tokens=True)
print(translated)
print(o)

# TF generates 512 tokens, although the decoded version gives the same result as PT -> very slow
model = TFMarianMTModel.from_pretrained(model_name, from_pt=False)
batch = tokenizer(text_in, return_tensors='tf', padding=True)
translated = model.generate(**batch)
o = tokenizer.batch_decode(translated, skip_special_tokens=True)
print(translated)
print(o)
```
| f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." | ||
| ) | ||
|
|
||
| if "max_length" in kwargs: |
`max_length` is now a mandatory argument, so this warning no longer makes sense. The test confirming that this warning is raised was also removed.
I think we can then also delete `kwargs` from `__init__`.
```python
if self.length_penalty > 0.0:
    cur_score = best_sum_logprobs / self.max_length**self.length_penalty
else:
    cur_score = best_sum_logprobs / cur_len**self.length_penalty
```
Implements the logic described in the PR header. `max_length` is now needed as an input, which required some changes in the tests (not in terms of results, but rather in terms of class initialization).
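For illustration, a minimal standalone sketch of this scoring rule (the function wrapper and the example numbers are mine, not the actual class method):

```python
def best_possible_score(best_sum_logprobs, cur_len, max_length, length_penalty):
    # Sums of log-probabilities are <= 0, so a larger denominator moves the
    # score closer to 0 (i.e. makes it higher).
    if length_penalty > 0.0:
        # Longer sentences are promoted: the best case is reaching max_length.
        return best_sum_logprobs / max_length**length_penalty
    # Shorter sentences are promoted: the best case is stopping at the current length.
    return best_sum_logprobs / cur_len**length_penalty

print(best_possible_score(-4.0, cur_len=5, max_length=20, length_penalty=1.0))   # -0.2
print(best_possible_score(-4.0, cur_len=5, max_length=20, length_penalty=-1.0))  # -20.0
```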
```diff
      state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
  )
- improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
+ improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
```
(see comment on the TF implementation)
For reference this was copied from: https://github.com/google/flax/blob/2d79fdb5adeb97e610f72e22c8cfb148a7017556/examples/wmt/decode.py#L223
incorrectly copied haha - good catch!
```diff
      is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
  )
- improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+ improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
```
Before: improvement was possible only when ALL finished scores were worse than the best running scores. In other words, if batch member 0 had running candidates that were better than the finished candidates' scores, but batch member 1 did not, this condition would evaluate to False because of batch member 1. This means that we were terminating beam search even though an improvement was still possible for batch member 0.
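To see the difference on concrete values, a small hypothetical example (the tensors below are made up; shapes follow the `[batch, 1]` layout used in the diff):

```python
import tensorflow as tf

# Batch member 0 can still improve; batch member 1 cannot.
best_running_score = tf.constant([[-1.0], [-5.0]])
worst_finished_score = tf.constant([[-2.0], [-3.0]])

# Old condition: only True if EVERY batch member can still improve -> False here,
# so beam search stopped even though member 0 still had room to improve.
print(tf.math.reduce_all(worst_finished_score < best_running_score))  # False

# Fixed condition: True if ANY batch member can still improve -> True here.
print(tf.math.reduce_any(best_running_score > worst_finished_score))  # True
```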
```diff
  still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)

- return not_max_length_yet & (still_open_beam | improvement_still_possible)
+ return not_max_length_yet & still_open_beam & improvement_still_possible
```
🙈 this is what happens when test fixing is done without thinking deeply about the subject: the previous condition, combined with the previous handling of `improvement_still_possible` and `best_running_score`, made all tests pass -- but for the wrong reasons. `early_stopping=True` was not operating as intended before.
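A plain-Python sketch of why the operator change matters (the values below are hypothetical):

```python
not_max_length_yet = True
still_open_beam = False            # all sentences finished and early_stopping=True
improvement_still_possible = True  # in theory, scores could still improve

# Old: the `|` kept decoding whenever improvement seemed possible,
# overriding the user's early_stopping request.
old = not_max_length_yet and (still_open_beam or improvement_still_possible)  # True -> keeps going

# New: all three conditions must hold to continue, so early stopping is honored.
new = not_max_length_yet and still_open_beam and improvement_still_possible   # False -> stops

print(old, new)  # True False
```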
```python
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
    attr = getattr(config, var_name, None)
    if attr is not None and attr < generate_kwargs["max_new_tokens"]:
        try:
            setattr(config, var_name, generate_kwargs["max_new_tokens"])
        except NotImplementedError:
            # xlnet will raise an exception when trying to set
            # max_position_embeddings.
            pass
```
This was incorrectly removed here, causing some XLA tests to fail.
```diff
  'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
  " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
- " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+ " magazine says . all 150 on board were killed in the crash .",
```
Ran slow tests on T5, GPT2, and BART for all 3 frameworks; this was the only observed difference. Note that Flax's output for this particular test was already different from TF's and PT's outputs. Also, it fits one of the criteria for incorrect Flax results (`batch_size > 1`).

(I suspect that slow generation tests for other models may have mismatches; I'm delegating the task of tracking and fixing them to the daily CI.)
sgugger left a comment
Thanks for all the fixes! LGTM!
```python
    self,
    batch_size: int,
    num_beams: int,
    max_length: int,
```
BeamSearchScorer is a public class (even though I don't think it's used that much on its own), so do you think we could maybe do:
```diff
- max_length: int,
+ max_length: Optional[int] = None,
```
and throw an error if `do_early_stopping` is set to `False`, saying that one should use `do_early_stopping=True` to not have to pass `max_length`.
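A hedged sketch of what that suggestion could look like (the argument list is trimmed and the error message wording is mine, not the merged code):

```python
from typing import Optional

class BeamSearchScorer:
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        max_length: Optional[int] = None,
        do_early_stopping: bool = False,
    ):
        if not do_early_stopping and max_length is None:
            raise ValueError(
                "`max_length` is required when `do_early_stopping=False`. "
                "Pass `max_length`, or set `do_early_stopping=True` to skip it."
            )
        self.max_length = max_length
```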
```python
if self.length_penalty > 0.0:
    cur_score = best_sum_logprobs / self.max_length**self.length_penalty
else:
    cur_score = best_sum_logprobs / cur_len**self.length_penalty
```
great find!
```diff
  # 2. can the new beams still improve?
- best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
+ if length_penalty > 0.0:
```
Does this work with XLA? Does it run without errors?
That's a great find! Well done on finding the inconsistency here. While this change is mathematically completely correct, I'm a bit worried whether it leads to bad/annoying side-effects in practice. I think most people don't think too deeply about `length_penalty`. There are some problems here, I think:

2.) is not a big problem, but I'm a bit worried that 1.) is one. What do you think about 1.), @gante - especially when looking at generation configs like the one of BART (the model is downloaded a lot and has many "derivation" models)?

The change here is definitely logically/mathematically correct, but I'm worried that it has too many negative effects. It's also a bit unreasonable when doing the math. On the other hand, maybe the log probs become so low so quickly for bad results that this change doesn't have that much of an impact.

Can we maybe run some tests here, @gante? Maybe with the default setting of https://huggingface.co/facebook/bart-large-cnn/blob/main/config.json#L42 . If there are no major changes in outputs, ok to merge for me!

Also, should we maybe add a warning along the lines of "We detected that you use..."?
@patrickvonplaten I agree entirely with your points above. Yes, these changes are technically correct, but the cost can be quite high -- here's a rundown of the results in a few models, for the PT changes:

This probably means that the output text will only see changes in corner cases, which removes some of our concerns regarding this PR. However, the additional computational cost can be prohibitively high in some typical applications. That will likely create annoyed users, which does not seem wise.

So, what can we do here? I struggle to see a good compromise solution 🤔 Given that many research groups use our code to conduct research, I'd like to avoid a) (i.e. keeping the bug). For downstream users, assuming that most wouldn't react to announcements, we will have to pick between keeping a bug and risking changing behavior :( Personally, I'd go with d), but it is extremely debatable (and you folks probably have more experience).

P.S.: TF XLA benchmarks showed that it was not much faster than PT with beam search. Maybe this problem explains part of it!
ydshieh left a comment
Thank you @gante ❤️ !

My only concern is the added positional argument, similar to what @patrickvonplaten said in https://github.com/huggingface/transformers/pull/20901/files#r1059400589

Note that the suggested change (if you decide to apply it) will require changing the argument's position.
Hmmm, ok this is a very tricky one then :-/

Just to better understand: are there a lot of cases where the current implementation (the correct use of length penalty) leads to better results? Could you maybe post some failure cases of the current implementation?
Another option would be to frame everything as setting a "lower bound". Currently, we have a "heuristic lower bound" in PT; another option, as done in this PR, is an "absolute lower bound".
@patrickvonplaten some data about a potential

Looking at the catastrophic failure in the TF test, having the right

I like the "lower bound" framing, with users being able to pick how precise they want to be in their beam search while keeping the current defaults. However, I'm reluctant to add yet another flag. We could change the

That way:

WDYT?
Nice, good idea! I like the idea of using

Guess we have to leave the reasoning of
Applied the contents of the discussion in #21368, closing this one.
What does this PR do?
As initially uncovered by @ydshieh in #20853, there is a gross TF/PT mismatch on the number of steps beam search takes under some circumstances. In practice, all three frameworks had a different and incomplete implementation (see below why), and this PR fixes it.
Added "🚨🚨" to the title, as this PR may change the output of beam search.
Rationale:
We know that `logprobs` is a negative value, and we want to maximize it in beam search (i.e. make it as close to 0 as possible). Since `logprobs` is always negative, and the final score is the sum of the logprobs, we can anticipate the best possible score a running sequence can ever achieve, and use it to terminate beam search early with no drawback (without this shortcut, beam search will always run `max_length` steps unless `early_stopping=True`). Well, it turns out that the method to compute the best possible score depends on the sign of `length_penalty`, and we were not accounting for that!

- `length_penalty > 0.0`: In this case, as the sentence grows, the denominator grows as well. This means the score can get closer to 0 (i.e. higher) as the sentence grows, and longer sentences are promoted. In this case, the best possible score can be determined from the maximum sequence length (original TF/FLAX implementation).
- `length_penalty < 0.0`: In this case, as the sentence grows, the denominator gets smaller. This means the score will get farther away from 0 (i.e. lower) as the sentence grows, and shorter sentences are promoted. In this case, the best possible score can be determined from the current sequence length (original PT implementation).

On top of this, FLAX and TF were incorrectly terminating early when `batch_size > 1`: we were saying that a score improvement was no longer possible as soon as one of the batch members could no longer improve (as opposed to when all batch members could no longer improve).

Finally, there was an issue with TF where early stopping was not correctly triggered (my bad).
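To make the PT-side bug concrete, a toy check with made-up numbers (`length_penalty > 0`, the default):

```python
best_sum_logprobs, cur_len, max_length, length_penalty = -6.0, 3, 12, 1.0
best_finished_score = -1.5  # score of the best finished hypothesis so far

old_pt_bound = best_sum_logprobs / cur_len**length_penalty    # -2.0
fixed_bound = best_sum_logprobs / max_length**length_penalty  # -0.5

# Old PT: -2.0 < -1.5, so it declared "no improvement possible" and stopped,
# even though the running beam could still reach up to -0.5 by growing to max_length.
print(old_pt_bound < best_finished_score)  # True -> terminated too early
print(fixed_bound > best_finished_score)   # True -> improvement was still possible
```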
In summary, for different reasons, all frameworks were stopping beam search incorrectly under certain circumstances:

- PT: `length_penalty > 0.0` (which is the default case!)
- FLAX: `batch_size > 1` || `length_penalty < 0.0`
- TF: `batch_size > 1` || `length_penalty < 0.0` || incorrect (missing) early stopping trigger