[WIP] Add Flax diverse group search #24508
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -751,6 +751,7 @@ def _beam_search( | |
| trace: bool = True, | ||
| params: Optional[Dict[str, jnp.ndarray]] = None, | ||
| num_return_sequences: Optional[int] = None, | ||
| num_beam_groups: Optional[int] = 1, | ||
| model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, | ||
| ): | ||
| """ | ||
|
|
@@ -801,6 +802,8 @@ def gather_fn(tensor): | |
|
|
||
| batch_size, num_beams, cur_len = input_ids.shape | ||
|
|
||
| group_size = num_beams // num_beam_groups | ||
|
|
||
| eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) | ||
| pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) | ||
| cur_len = jnp.array(cur_len) | ||
|
|
@@ -954,7 +957,20 @@ def beam_search_body_fn(state, input_ids_length=1): | |
| state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape | ||
| ) & (early_stopping is True) | ||
| add_penalty = ~did_topk_just_finished | beams_in_batch_are_full | ||
| topk_log_probs += add_penalty * np.array(-1.0e7) | ||
|
|
||
| # Add additional logic for diverse beam search | ||
|
Contributor
Nice! My only nit is that we try and avoid
Author
Yes
Contributor
Lovely, thanks |
||
| get_finished_indices_fn = lambda x: jax.numpy.where(x, size=beams_to_keep, fill_value=-1) | ||
| just_finished_indices = jax.vmap(get_finished_indices_fn, in_axes=0)(did_topk_just_finished)[0] | ||
|
|
||
| just_finished_indices_within_group_size = ((just_finished_indices < group_size) & (just_finished_indices >= 0)).astype(jnp.int32) | ||
|
|
||
| get_not_finished_indices_truncated_to_group_size_fn = lambda x: jax.numpy.where(x==0, size=group_size) | ||
| not_finished_indices_truncated_to_group_size = jax.vmap(get_not_finished_indices_truncated_to_group_size_fn, in_axes=0)(just_finished_indices_within_group_size)[0] | ||
|
|
||
| set_group_penalty_mask_fn = lambda x, idx: x.at[idx].set(1) == 0 | ||
| group_penalty = jax.vmap(set_group_penalty_mask_fn, in_axes=0)(just_finished_indices_within_group_size, not_finished_indices_truncated_to_group_size) | ||
|
|
||
| topk_log_probs += (add_penalty * np.array(-1.0e7) + group_penalty * np.array(-1.0e7)) | ||
|
|
||
| # 7. Get scores, sequences, is sentence finished for next. | ||
| # Combine sequences, scores, and flags along the beam dimension and compare | ||
|
|
||
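
The diff above derives `group_size = num_beams // num_beam_groups`. As a minimal sketch of what that partition implies (plain Python, not the PR's code; the helper name `beam_group_slices` and the divisibility check are assumptions made for illustration), the beams of one batch entry could be split into contiguous groups like this:

```python
from typing import List


def beam_group_slices(num_beams: int, num_beam_groups: int) -> List[range]:
    """Hypothetical helper: split beam indices into equally sized contiguous groups."""
    if num_beam_groups < 1:
        raise ValueError("num_beam_groups must be >= 1")
    if num_beams % num_beam_groups != 0:
        # Assumption: the groups must divide the beams evenly, otherwise
        # integer division would silently drop the trailing beams.
        raise ValueError("num_beams must be divisible by num_beam_groups")
    group_size = num_beams // num_beam_groups  # same expression as in the diff
    return [
        range(g * group_size, (g + 1) * group_size)
        for g in range(num_beam_groups)
    ]


# Six beams in three groups: each group owns two consecutive beam indices.
groups = beam_group_slices(num_beams=6, num_beam_groups=3)
print([list(g) for g in groups])
```

With `num_beam_groups=1` the single group spans every beam, which is the prototype case the author describes in the thread.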
In PyTorch we define a separate beam search method for group beam search:
transformers/src/transformers/generation/utils.py, line 3375 in 33b5ef5
We only trigger this method if num_beam_groups > 1:
transformers/src/transformers/generation/utils.py, line 1426 in 33b5ef5
My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one. IMO this is cleaner for the reader and more compartmentalised for building on.
cc @gante as well for the Flax generate design decision
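
A minimal sketch of the dispatch suggested in this comment, assuming a hypothetical Flax-side `_group_beam_search` method next to the existing `_beam_search`. The class, the method bodies, and the return values are placeholders for illustration, not the library's API:

```python
class GenerationSketch:
    """Toy stand-in for a generation mixin; all names except _beam_search are hypothetical."""

    def _beam_search(self, input_ids, num_beams):
        # Placeholder for the existing (non-grouped) beam search.
        return ("beam_search", num_beams)

    def _group_beam_search(self, input_ids, num_beams, num_beam_groups):
        # Placeholder for a separate diverse (group) beam search method.
        if num_beams % num_beam_groups != 0:
            # Assumption: beams must split evenly across groups.
            raise ValueError("num_beams must be divisible by num_beam_groups")
        return ("group_beam_search", num_beams, num_beam_groups)

    def generate(self, input_ids, num_beams=1, num_beam_groups=1):
        # Only take the grouped path when more than one group is requested,
        # so plain beam search stays untouched.
        if num_beam_groups > 1:
            return self._group_beam_search(input_ids, num_beams, num_beam_groups)
        return self._beam_search(input_ids, num_beams)


sketch = GenerationSketch()
print(sketch.generate([0], num_beams=4))
print(sketch.generate([0], num_beams=4, num_beam_groups=2))
```

Keeping the grouped logic in its own method means the existing beam search body does not need group-specific masking, at the cost of some duplicated setup code.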
Thanks @sanchit-gandhi!
My first commit was to get a prototype working for num_beam_groups=1. I intend to refactor the beam search logic to make sure it works for other num_beam_groups sizes.
+1 :)
(btw, there was a recent bugfix on the PT side, might be relevant here)
Awesome, sounds good @yeandy! Excited to see how this pans out!