🚨🚨 Generate: correct beam search best possible score computation and handling #20901
src/transformers/generation/beam_search.py

```diff
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import List, Optional, Tuple
```
```diff
@@ -156,6 +155,7 @@ def __init__(
         self,
         batch_size: int,
         num_beams: int,
+        max_length: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
```
```diff
@@ -167,6 +167,7 @@ def __init__(
         self.device = device
         self.length_penalty = length_penalty
         self.do_early_stopping = do_early_stopping
+        self.max_length = max_length
         self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
         self.num_beam_groups = num_beam_groups
         self.group_size = self.num_beams // self.num_beam_groups
```
```diff
@@ -177,6 +178,7 @@ def __init__(
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=self.max_length,
             )
             for _ in range(batch_size)
         ]
```
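For context, a minimal usage sketch of the updated constructor; it assumes `BeamSearchScorer` is importable from `transformers` as on this branch, and all argument values are illustrative:

```python
import torch
from transformers import BeamSearchScorer

# `max_length` is now a required constructor argument, matching the signature above;
# the remaining values are made up for illustration.
scorer = BeamSearchScorer(
    batch_size=2,
    num_beams=4,
    max_length=32,
    device=torch.device("cpu"),
    length_penalty=1.0,
    do_early_stopping=False,
)
```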
```diff
@@ -194,13 +196,6 @@ def __init__(
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
```

Contributor (on the removed deprecation block): Think we can delete
```diff
@@ -424,6 +419,7 @@ def __init__(
         self,
         batch_size: int,
         num_beams: int,
+        max_length: int,
         constraints: List[Constraint],
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
```
```diff
@@ -436,6 +432,7 @@ def __init__(
         self.device = device
         self.length_penalty = length_penalty
         self.do_early_stopping = do_early_stopping
+        self.max_length = max_length
         self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
         self.num_beam_groups = num_beam_groups
         self.group_size = self.num_beams // self.num_beam_groups
```
```diff
@@ -447,6 +444,7 @@ def __init__(
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=self.max_length,
             )
             for _ in range(batch_size)
         ]
```
```diff
@@ -464,13 +462,6 @@ def __init__(
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
```
```diff
@@ -851,12 +842,13 @@ def finalize(


 class BeamHypotheses:
-    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: int):
         """
         Initialize n-best list of hypotheses.
         """
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
+        self.max_length = max_length
         self.num_beams = num_beams
         self.beams = []
         self.worst_score = 1e9
```
```diff
@@ -892,6 +884,9 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
         elif self.early_stopping:
             return True
         else:
-            cur_score = best_sum_logprobs / cur_len**self.length_penalty
+            if self.length_penalty > 0.0:
+                cur_score = best_sum_logprobs / self.max_length**self.length_penalty
+            else:
+                cur_score = best_sum_logprobs / cur_len**self.length_penalty
             ret = self.worst_score >= cur_score
             return ret
```

Contributor (on the new `max_length` bound): great find!

Contributor (Author, on lines +887 to +890): Implements the logic as described in the PR header.
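To see why the branch matters: beam scores are sums of token log-probabilities, so `best_sum_logprobs` is negative. With `length_penalty > 0.0` the denominator grows with sequence length, making the score less negative, so the best score a running beam can still reach is the one at `max_length`; with `length_penalty <= 0.0`, growing longer scores the same or worse, so `cur_len` already gives the bound. A standalone sketch of the updated computation, with illustrative numbers:

```python
def best_possible_score(best_sum_logprobs: float, cur_len: int, max_length: int, length_penalty: float) -> float:
    # Mirrors the updated `BeamHypotheses.is_done` bound: a positive length
    # penalty rewards longer sequences, so the optimum lies at `max_length`.
    if length_penalty > 0.0:
        return best_sum_logprobs / max_length**length_penalty
    return best_sum_logprobs / cur_len**length_penalty

print(best_possible_score(-10.0, cur_len=5, max_length=20, length_penalty=1.0))   # -0.5  (old code gave -2.0)
print(best_possible_score(-10.0, cur_len=5, max_length=20, length_penalty=0.0))   # -10.0 (length has no effect)
print(best_possible_score(-10.0, cur_len=5, max_length=20, length_penalty=-1.0))  # -50.0 (shorter is better)
```

Because the old bound (`-2.0` in the first case) understated what a running beam could still achieve, `is_done` could return `True` too early and prune beams that might have overtaken the finished hypotheses.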
src/transformers/generation/flax_utils.py

```diff
@@ -797,11 +797,14 @@ def beam_search_cond_fn(state):
             not_max_length_yet = state.cur_len < max_length

             # 2. can the new beams still improve?
-            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
+            if length_penalty > 0.0:
+                best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
+            else:
+                best_running_score = state.running_scores[:, -1:] / (state.cur_len**length_penalty)
             worst_finished_score = jnp.where(
                 state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
             )
-            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
+            improvement_still_possible = jnp.any(best_running_score > worst_finished_score)

             # 3. is there still a beam that has not finished?
             still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
```

Contributor (on the `length_penalty` branch): Does this work with XLA? Gives no errors?

Contributor (Author, on the `improvement_still_possible` change): (see comment on the TF implementation)

Contributor: For reference, this was copied from: https://github.com/google/flax/blob/2d79fdb5adeb97e610f72e22c8cfb148a7017556/examples/wmt/decode.py#L223

Contributor: incorrectly copied haha - good catch!
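On the XLA question: `length_penalty` here is a plain Python float rather than a traced value, so the `if` should be resolved once at trace time instead of becoming a traced conditional. A minimal sketch under that assumption (the function name and values are hypothetical):

```python
import jax
import jax.numpy as jnp

length_penalty = 0.0  # static Python float, fixed when the function is traced

@jax.jit
def best_running_score_sketch(running_scores, cur_len, max_length):
    # This Python-level `if` branches on a static value, so it is evaluated
    # during tracing and compiles cleanly under XLA.
    if length_penalty > 0.0:
        return running_scores / (max_length**length_penalty)
    return running_scores / (cur_len**length_penalty)

print(best_running_score_sketch(jnp.array([-10.0]), 5.0, 20.0))  # [-10.]
```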
src/transformers/generation/tf_utils.py

```diff
@@ -3066,16 +3066,19 @@ def beam_search_cond_fn(
             not_max_length_yet = cur_len < max_length

             # 2. can the new beams still improve?
-            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            if length_penalty > 0.0:
+                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             worst_finished_score = tf.where(
                 is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
             )
-            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+            improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)

             # 3. is there still a beam that has not finished?
             still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)

-            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+            return not_max_length_yet & still_open_beam & improvement_still_possible

         def beam_search_body_fn(
             cur_len,
```

Contributor (Author, on the `reduce_any` change): Before: improvement was possible when ALL finished scores were worse than the best running scores. In other words, if batch member 0 had running candidates that were better than the finished candidates' scores, but batch member 1 did not, this condition would evaluate to `False`, even though batch member 0 could still improve.

Contributor (Author, on the new return condition): 🙈 this is what happens when test fixing is done without thinking deeply on the subject: the previous condition, combined with the previous handling of `improvement_still_possible`, meant the improvement check had almost no effect — with `early_stopping=False`, `still_open_beam` is always `True`, so the `|` kept the loop running to `max_length` regardless.
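The `reduce_all` → `reduce_any` change is easiest to check on a toy batch; a framework-free sketch with made-up scores, where member 0 can still improve and member 1 cannot:

```python
worst_finished_score = [-1.0, -0.5]  # per batch member (illustrative)
best_running_score = [-0.5, -2.0]

# Old condition: improvement required EVERY batch member to be improvable.
old = all(w < b for w, b in zip(worst_finished_score, best_running_score))

# New condition: improvement is possible if ANY batch member can still improve.
new = any(b > w for b, w in zip(best_running_score, worst_finished_score))

print(old, new)  # False True -> the old check could end generation too early
```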
Contributor: `BeamSearchScorer` is a public class (even though I don't think it's used that much on its own). Do you think we could maybe keep `max_length` optional, and throw an error if `do_early_stopping` is set to `False` that says one should set `do_early_stopping=True` to not have to pass `max_length`?
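A sketch of what that suggestion might look like — hypothetical, not part of the PR; the class name and error message are made up:

```python
from typing import Optional

class BeamSearchScorerSketch:  # hypothetical stand-in for BeamSearchScorer
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        max_length: Optional[int] = None,  # optional, per the reviewer's suggestion
        do_early_stopping: bool = False,
    ):
        # With early stopping, `is_done` returns as soon as enough hypotheses are
        # collected, so the `max_length`-based score bound is never consulted.
        if max_length is None and not do_early_stopping:
            raise ValueError(
                "`max_length` is required when `do_early_stopping=False`; set "
                "`do_early_stopping=True` to avoid passing `max_length`."
            )
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.max_length = max_length
        self.do_early_stopping = do_early_stopping
```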