
Conversation

@gante (Contributor) commented on Dec 26, 2022

What does this PR do?

As initially uncovered by @ydshieh in #20853, there is a gross TF/PT mismatch on the number of steps beam search takes under some circumstances. In practice, all three frameworks had a different and incomplete implementation (see below why), and this PR fixes it.

Added "🚨🚨" to the title, as this PR may change the output of beam search.

Rationale:

We know that logprobs is a negative value, and we want to maximize it in beam search (i.e. make it as close to 0 as possible). Since logprobs is always negative, and the final score is the sum of the logprobs, we can anticipate the best possible score a running sequence can ever achieve, and use it to terminate beam search early with no drawback (without this shortcut, beam search will always run max_length steps unless early_stopping=True). Well, it turns out that the method to compute the best possible score depends on the sign of length_penalty, and we were not accounting for that!

  • Scenario 1, length_penalty > 0.0: In this case, as the sentence grows, the denominator grows as well. This means the score can get closer to 0 (i.e. higher) as the sentence grows, and longer sentences are promoted. Here, the best possible score can be determined from the maximum sequence length (original TF/FLAX implementation).
  • Scenario 2, length_penalty < 0.0: In this case, as the sentence grows, the denominator gets smaller. This means the score gets farther from 0 (i.e. lower) as the sentence grows, and shorter sentences are promoted. Here, the best possible score can be determined from the current sequence length (original PT implementation). (A numeric sketch of both scenarios follows right after this list.)
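
To make the two scenarios concrete, here is a small sketch with made-up numbers (not the actual transformers code) showing where the best achievable score comes from in each case:

best_sum_logprobs = -10.0   # sum of log-probs so far: always <= 0, and it only decreases as the beam grows
cur_len, max_length = 20, 128

# Scenario 1, length_penalty > 0.0: the denominator grows with length, so the score
# can still move toward 0 -> bound the best achievable score using max_length
best_possible_pos = best_sum_logprobs / max_length**1.0    # -10 / 128     ~= -0.078

# Scenario 2, length_penalty < 0.0: the denominator shrinks with length, so the score
# only moves away from 0 -> the current length already gives the best achievable score
best_possible_neg = best_sum_logprobs / cur_len**-1.0      # -10 / (1/20)   = -200.0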

On top of this, FLAX and TF were incorrectly terminating early when batch_size > 1: we were declaring that a score improvement was no longer possible as soon as one of the batch members could no longer improve (as opposed to when all batch members can no longer improve).

Finally, there was an issue with TF where early stopping was not correctly triggered (my bad).

In summary, for different reasons, all frameworks were stopping beam search incorrectly under certain circumstances:

  1. PT: when length_penalty > 0.0 (which is the default case!)
  2. Flax: with batch_size > 1 || length_penalty < 0.0
  3. TF: with batch_size > 1 || length_penalty < 0.0 || incorrect (missing) early stopping trigger.

@HuggingFaceDocBuilderDev commented on Dec 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante gante marked this pull request as ready for review December 26, 2022 17:45
@gante (Contributor Author) commented on Dec 26, 2022

@ydshieh regarding the original issue (#18149) -- the problem was not TF with too many beam search iterations, but rather PT with not enough 😅 After this fix, in the example you shared (which I paste below, for reference), both PT and TF run >300 steps to conclude that "bonjour" is the answer. Please note that TF includes the padding in its output (as opposed to PT, which doesn't) because its output tensors are pre-padded and sliced based on the number of iterations, whereas in PT they are growing tensors that can be stored as candidate outputs without padding.

early_stopping=True can be used with TF for quicker results.


Python example:

from transformers import MarianMTModel, MarianTokenizer, TFMarianMTModel
import tensorflow as tf

model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
tokenizer = MarianTokenizer.from_pretrained(model_name)
text_in = ['>>fr<< hello']

# PT generates a few tokens then stops early -> very fast
model = MarianMTModel.from_pretrained(model_name)
batch = tokenizer(text_in, return_tensors='pt', padding=True)
translated = model.generate(**batch)
o = tokenizer.batch_decode(translated, skip_special_tokens=True)

print(translated)
print(o)

# TF generates 512 tokens, although the decoded version gives the same result as PT -> very slow
model = TFMarianMTModel.from_pretrained(model_name, from_pt=False)
batch = tokenizer(text_in, return_tensors='tf', padding=True)
translated = model.generate(**batch)
o = tokenizer.batch_decode(translated, skip_special_tokens=True)

print(translated)
print(o)

f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)

if "max_length" in kwargs:
@gante (Contributor Author):

max_length is now a mandatory argument, so this warning no longer makes sense. The test that confirms that this warning is thrown was also removed.

Contributor:

I think we can then also delete kwargs from the __init__.

Comment on lines +887 to +890
if self.length_penalty > 0.0:
cur_score = best_sum_logprobs / self.max_length**self.length_penalty
else:
cur_score = best_sum_logprobs / cur_len**self.length_penalty
@gante (Contributor Author) commented on Dec 26, 2022:

Implements the logic as described in the PR header. max_length is now needed as an input, which required some changes in the tests (not in terms of results, but rather in terms of class initialization).

state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
)
improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
@gante (Contributor Author):

(see comment on the TF implementation)

Contributor:

incorrectly copied haha - good catch!

is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
@gante (Contributor Author):

Before: improvement was considered possible only when ALL finished scores were worse than the best running scores. In other words, if batch member 0 had running candidates that were better than the finished candidates' scores, but batch member 1 did not, this condition would evaluate to False because of batch member 1. This means that we were terminating beam search even though an improvement was still possible for batch member 0.
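
For illustration, a minimal sketch (not from the PR) with made-up scores for a batch of two, where member 0 can still improve but member 1 cannot:

import tensorflow as tf

# shape (batch_size, 1): member 0 can still improve, member 1 cannot
worst_finished_score = tf.constant([[-2.0], [-1.0]])
best_running_score = tf.constant([[-1.5], [-3.0]])

# old check: only True if *every* batch member can still improve -> False here,
# so beam search stopped even though member 0 had a better running candidate
old = tf.math.reduce_all(worst_finished_score < best_running_score)

# new check: True if *any* batch member can still improve -> True here,
# so beam search keeps going until no member can improve
new = tf.math.reduce_any(best_running_score > worst_finished_score)

print(old.numpy(), new.numpy())  # False True

With the any-based check, the loop only terminates once every batch member's best running score has fallen below its worst finished score.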

still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)

return not_max_length_yet & (still_open_beam | improvement_still_possible)
return not_max_length_yet & still_open_beam & improvement_still_possible
@gante (Contributor Author):

🙈 this is what happens when test fixing is done without thinking deeply on the subject: the previous condition, combined with the previous handling of improvement_still_possible and best_running_score, made all tests pass. But for the wrong reasons -- early_stopping=True was not operating as intended before

Comment on lines +1854 to +1863
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
attr = getattr(config, var_name, None)
if attr is not None and attr < generate_kwargs["max_new_tokens"]:
try:
setattr(config, var_name, generate_kwargs["max_new_tokens"])
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
@gante (Contributor Author):

This was incorrectly removed here, causing some XLA tests to fail.

'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
" magazine says . all 150 on board were killed in the crash .",
@gante (Contributor Author) commented on Dec 26, 2022:

Ran slow tests on T5, GPT2, and BART for all 3 frameworks; this was the only observed difference. It is to be noted that Flax's output for this particular test was already different from TF's and PT's outputs. It also fits one of the criteria for incorrect Flax results (batch_size > 1).

(I suspect that slow generation tests for other models may have mismatches; I'm delegating the task of tracking and fixing them to the daily CI.)

@sgugger (Collaborator) left a comment:

Thanks for all the fixes! LGTM!

self,
batch_size: int,
num_beams: int,
max_length: int,
Contributor:

BeamSearchScorer is a public class (even though I don't think it's used that much on its own); do you think we could maybe do:

Suggested change
max_length: int,
max_length: Optional[int] = None,

and throw an error if do_early_stopping is set to False, saying that one should set do_early_stopping=True in order not to have to pass max_length.
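
If that suggestion were applied, the validation could look roughly like the sketch below (hypothetical class and argument names; not the actual transformers implementation):

from typing import Optional

class BeamScorerSketch:
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        max_length: Optional[int] = None,  # suggested: optional instead of mandatory
        do_early_stopping: bool = False,
        length_penalty: float = 1.0,
    ):
        # without early stopping, the best-possible-score bound needs max_length
        # whenever length_penalty > 0.0, so require it up front
        if not do_early_stopping and max_length is None:
            raise ValueError(
                "`max_length` must be passed when `do_early_stopping=False`; "
                "set `do_early_stopping=True` if you do not want to pass `max_length`."
            )
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.max_length = max_length
        self.do_early_stopping = do_early_stopping
        self.length_penalty = length_penalty

Any such default would also have to keep the argument order backwards compatible, as noted further down in the review.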

else:
cur_score = best_sum_logprobs / cur_len**self.length_penalty
if self.length_penalty > 0.0:
cur_score = best_sum_logprobs / self.max_length**self.length_penalty
Contributor:

great find!


# 2. can the new beams still improve?
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
if length_penalty > 0.0:
Contributor:

Does this work with XLA? Gives no errors?

@patrickvonplaten (Contributor) commented on Dec 30, 2022

That's a great find! Well done, on finding the inconsistency here.

While this change is mathematically completely correct, I'm a bit worried about whether it leads to bad/annoying side-effects in practice. I think most people don't think too deeply about length_penalty and just use a parameter that works "well enough".

There are some problems here I think:

  • 1.) As noted, the default case is length_penalty=1.0 and do_early_stopping=False, which means that this PR changes the default case of all beam search applications. While it will certainly always improve the output result "mathematically", there are two problems in practice:
  • 1.1) Some people have probably unknowingly found a high length_penalty to work reasonably well. A high length_penalty combined with a high max_length can now lead to beam search giving some super long results as the best solution (which would be mathematically correct given the high length_penalty, but I don't think people understand/understood the length penalty well enough to understand why this is).
  • 1.2) Beam search will now always run much, much longer if max_length is very high (there are lots of models that set max_length to something like 128 or even 256 for short-sentence tasks like translation).
  • 2.) (smaller problem) - we were trying to move away from having to require max_length overall - ideally the user should be able to use any kind of stopping criteria with beam search.

2.) is not a big problem, but I'm a bit worried that 1.) is one. What do you think about 1.), @gante - especially when looking at generation configs like the one of BART (the model is downloaded a lot and has many derived models)?

The change here is definitely logically/mathematically correct, but I'm worried that it has too many negative effects. It's also a bit unreasonable when doing the math:

best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)

for max_length=256 and length_penalty=2 will essentially make beam search rarely stop before the end: x/(256*256) = x/65536 is very close to 0 for log-probs, no? Or do log-probs become extremely negative as soon as the text becomes bad?

On the other hand, maybe the log-probs very quickly become so low for bad results that this change doesn't have that much of an impact. Can we maybe run some tests here, @gante? Maybe with the default setting of https://huggingface.co/facebook/bart-large-cnn/blob/main/config.json#L42 . If there are no major changes in outputs, it's ok to merge for me!

Also, should we maybe add a warning: "We detected that you use length_penalty > 1.0, which strongly encourages long sequences to be generated. Recently there has been a change that might cause your generation to last longer than expected and lead to different results. You might want to consider lowering the length_penalty."?
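
For intuition, a quick back-of-the-envelope sketch with made-up numbers shows why the bound would almost never trigger with these settings:

max_length, length_penalty = 256, 2.0

# a finished hypothesis: 30 tokens with a summed log-prob of -20
finished_score = -20.0 / 30**length_penalty                    # ~ -0.022

# a running hypothesis with a much worse summed log-prob of -100
best_possible_running = -100.0 / max_length**length_penalty    # ~ -0.0015

# the bound says the running beam could still beat the finished one,
# so beam search keeps iterating almost until max_length
print(best_possible_running > finished_score)  # True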

@gante (Contributor Author) commented on Jan 2, 2023

@patrickvonplaten I agree entirely with your points above. Yes, these changes are technically correct, but the cost can be quite high -- here's a rundown of the results in a few models, for the PT changes:

  1. Models with early_stopping=True in the config, such as facebook/bart-large-cnn: no output change, same number of beam search iterations 👍
  2. Models with early_stopping=False in the config, such as Marian or T5: no output change, one order of magnitude (!) more iterations for short inputs 🙅 This is because of what you wrote above -- the best_running_score can stay very high for a large number of iterations, even with length_penalty=1.0.

This probably means that the output text will only see changes in corner cases, which removes some of our concerns regarding this PR. However, the additional computational cost can be prohibitively high in some typical applications. That will likely create annoyed users, which does not seem wise.


So, what can we do here?
a) Don't merge some or all of the changes, especially on the PT side, since they introduce unwanted (although correct) behavior. [probably not great, as we would be intentionally keeping a bug in the code]
b) Add warnings so that users pick the right flags. [users ignore warnings most of the time...]
c) Add some flag and/or transformers version gating, to keep the old behavior. [adds complexity, undesirable and, like b), requires users to use flags]
d) Update the default length_penalty to 0.0, which stops biasing beam search toward long searches. In the examples I tried, this keeps the same outputs while not causing the number of beam search iterations to grow with this PR. [changing a default can be tricky, and some models might rely on length_penalty=1.0 to get the expected output. On the plus side, most users intuitively think that a positive length_penalty promotes shorter sentences, which is not true, so we might be killing two birds with one stone]
e) Update the default of early_stopping to True. [similar to d), but less good imo]

I struggle to see a good compromise solution 🤔 Given that many research groups use our code to conduct research, I'd like to avoid a) (i.e. keeping the bug). For downstream users, assuming that most wouldn't react to announcements, we will have to pick between keeping a bug and risking a behavior change :(

Personally, I'd go with d), but it is extremely debatable (and you folks probably have more experience).

P.S.: TF XLA benchmarks showed that it was not much faster with beam search, compared to PT. Maybe this problem explains part of it!

@ydshieh (Collaborator) left a comment:

Thank you @gante ❤️ !
My only concern is the added positional argument, similar to what @patrickvonplaten said:

https://github.com/huggingface/transformers/pull/20901/files#r1059400589

Note that the suggested change (if you decide to apply it) will require changing the argument position.

@patrickvonplaten (Contributor) commented on Jan 3, 2023

Hmmm, ok this is a very tricky one then :-/

length_penalty is a pretty important parameter, and it's somewhat natural IMO to bias the model to slightly prefer longer output lengths (as a longer output sequence always has a log-prob <= that of a shorter one). I think summarization models especially gain performance from using a length penalty.

Just to better understand, are there a lot of cases where the current implementation (the correct use of length penalty) leads to better results? Could you maybe post some failure cases of the current implementation?

@patrickvonplaten (Contributor):

Another option would be to frame everything as setting a "lower bound".

Currently, we have a "heuristic lower bound" in PT; another option, as done in this PR, is an "absolute lower bound".

@gante (Contributor Author) commented on Jan 3, 2023

@patrickvonplaten some data about a potential length_penalty change -- I've tried setting the default to 0.0 (from 1.0) and ran our test suite for potentially impacted tests. More precisely, running NVIDIA_TF32_OVERRIDE=0 RUN_SLOW=1 py.test tests/ -k WORD -vv, with WORD = {beam_search, summ, translat}, which catches most (or all) of the hard beam search tests on all 3 frameworks, had the following results:

  • 810 tests ran in total, including the challenging generate tests for beam search
  • 4 failed due to GPU OOM
  • 1 TF test failed (on T5-small, a translation outcome was ruined by the change -- Ich liebe es so sehr! to ! )
  • 1 PT test failed (on a pipeline test, a translation had 1 differing character but was equally correct -- هذا اختبار to هذا إختبار)

Looking at the catastrophic failure in the TF test, having the right length_penalty does make a difference, so a change may result in very annoyed users 👎


I like the "lower bound" framing, with users being able to pick how precise they want to be in their beam search while keeping the current defaults. However, I'm reluctant to add yet another flag. We could change the early_stopping flag from a binary one to a ternary one (like the verbose flag in many CLIs), as it already controls how long beam search runs. Something like:

  1. [no change] early_stopping = 0 would be equivalent to early_stopping = false (on PyTorch, i.e. stops in a few iterations because it does not consider the max_length when computing the best score). This would be the default;
  2. [no change] early_stopping = 1 would be equivalent to early_stopping = true;
  3. [new] early_stopping = -1 would be the mathematically correct (yet inefficient) best-possible-score computation.

That way:

  1. TF/FLAX would start behaving like PT, running fewer beam search iterations by default with minimal impact on the output;
  2. PT users would see no changes;
  3. Users still have the option of setting the mathematically correct version of beam search.

WDYT?

@patrickvonplaten (Contributor):
Nice, good idea! I like the idea of using early_stopping to decide what to do here! I would probably slightly favor:

early_stopping: Union[bool, str] = {False, True, "never"}

I guess we have to leave the reasoning of False as is for PyTorch. Using 1, 0, -1 is also ok for me, but I think it's nicer for the user to make early_stopping accept both str and bool.
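
As a rough illustration of how such a flag could gate the best-possible-score computation (a hypothetical helper, not the implementation that eventually landed in #21368):

from typing import Union

def best_possible_score(
    best_sum_logprobs: float,
    cur_len: int,
    max_length: int,
    length_penalty: float,
    early_stopping: Union[bool, str] = False,
) -> float:
    # early_stopping=True: stop as soon as enough hypotheses are finished, so no bound
    # is needed; return -inf so any finished score beats it
    if early_stopping is True:
        return float("-inf")
    # early_stopping=False: keep the current PT heuristic, based on the current length
    if early_stopping is False:
        return best_sum_logprobs / cur_len**length_penalty
    # early_stopping="never": the mathematically correct (but slow) bound, using
    # max_length whenever a positive length_penalty favors longer sequences
    if length_penalty > 0.0:
        return best_sum_logprobs / max_length**length_penalty
    return best_sum_logprobs / cur_len**length_penalty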

@gante (Contributor Author) commented on Jan 30, 2023

Applied the contents of the discussion in #21368, closing this one.
