Closed
18 changes: 17 additions & 1 deletion src/transformers/generation/flax_utils.py
@@ -751,6 +751,7 @@ def _beam_search(
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
num_beam_groups: Optional[int] = 1,
@sanchit-gandhi (Contributor) commented on Jun 28, 2023:
In PyTorch we define a separate beam search method for group beam search:

def group_beam_search(

We only trigger this method if num_beam_groups>1:

is_group_beam_gen_mode = (

My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one. IMO this is cleaner for the reader and more compartmentalised to build on.

cc @gante as well for Flax generate design decision
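For readers without the PyTorch source at hand, the dispatch being described can be sketched as follows (a simplified illustration of the `is_group_beam_gen_mode` check, not the actual transformers code; the `do_sample` condition is an assumption about how the generation modes are distinguished):

```python
def is_group_beam_gen_mode(num_beams: int, num_beam_groups: int, do_sample: bool) -> bool:
    # Group (diverse) beam search only applies when the beams are
    # split into more than one group and sampling is disabled.
    return num_beams > 1 and num_beam_groups > 1 and not do_sample


# Plain beam search: a single group of beams.
print(is_group_beam_gen_mode(num_beams=4, num_beam_groups=1, do_sample=False))  # False
# Diverse beam search: beams split across two groups.
print(is_group_beam_gen_mode(num_beams=4, num_beam_groups=2, do_sample=False))  # True
```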

@yeandy (Author) replied:
Thanks @sanchit-gandhi!

My first commit was to get a prototype working for num_beam_groups=1. I intend to refactor the beam search logic to make sure it works for other values of num_beam_groups.

  1. Will do.
  2. My current logic is jittable, as I've been doing some testing from this example. Are there tests in the HF repo that explicitly test whether a function is jittable? Or is it sufficient to have an E2E test that jits the function?
  3. Will do.
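One lightweight way to sanity-check jittability, independent of the transformers test suite, is to `jax.jit` the function directly and call it (a minimal sketch with a stand-in step function, not the PR's actual code):

```python
import jax
import jax.numpy as jnp


def penalize_and_reduce(log_probs, penalty_mask):
    # Stand-in for a piece of the beam-search body: apply a large
    # negative penalty to masked positions, then reduce over beams.
    return jnp.max(log_probs + penalty_mask * jnp.array(-1.0e7), axis=-1)


# If this traces and executes without a ConcretizationTypeError,
# the logic contains no jit-incompatible Python control flow.
jitted = jax.jit(penalize_and_reduce)
out = jitted(jnp.zeros((2, 4)), jnp.array([[1.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0]]))
print(out.shape)  # (2,)
```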

Contributor replied:

My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one.

+1 :)

(btw, there was a recent bugfix on the PyTorch side that might be relevant here)

Contributor replied:

Awesome, sounds good @yeandy! Excited to see how this pans out!

model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
"""
@@ -801,6 +802,8 @@ def gather_fn(tensor):

batch_size, num_beams, cur_len = input_ids.shape

group_size = num_beams // num_beam_groups

eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
@@ -954,7 +957,20 @@ def beam_search_body_fn(state, input_ids_length=1):
state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
) & (early_stopping is True)
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += add_penalty * np.array(-1.0e7)

# Add additional logic for diverse beam search
Contributor commented:

Nice! My only nit is that we try and avoid lambda functions in transformers - would you be able to re-write these as standard function definitions please?
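As an illustration of the requested style, the three lambdas from this PR's hunk could become named helpers along these lines (a sketch only; `beams_to_keep` and `group_size` are placeholders for the values computed inside `_beam_search`):

```python
import jax
import jax.numpy as jnp

beams_to_keep = 4  # placeholder for the value computed in _beam_search
group_size = 2     # placeholder for num_beams // num_beam_groups


def get_finished_indices(x):
    # Indices of finished beams, padded with -1 to a static size so
    # the output shape is known under jit.
    return jnp.where(x, size=beams_to_keep, fill_value=-1)


def get_not_finished_indices_truncated_to_group_size(x):
    return jnp.where(x == 0, size=group_size)


def set_group_penalty_mask(x, idx):
    return x.at[idx].set(1) == 0


did_topk_just_finished = jnp.array([[True, False, True, False]])
just_finished_indices = jax.vmap(get_finished_indices, in_axes=0)(did_topk_just_finished)[0]
print(just_finished_indices.tolist())  # indices 0 and 2, padded with -1
```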

Author replied:

Yes

Contributor replied:

Lovely thanks

get_finished_indices_fn = lambda x: jax.numpy.where(x, size=beams_to_keep, fill_value=-1)
just_finished_indices = jax.vmap(get_finished_indices_fn, in_axes=0)(did_topk_just_finished)[0]

just_finished_indices_within_group_size = ((just_finished_indices < group_size) & (just_finished_indices >= 0)).astype(jnp.int32)

get_not_finished_indices_truncated_to_group_size_fn = lambda x: jax.numpy.where(x==0, size=group_size)
not_finished_indices_truncated_to_group_size = jax.vmap(get_not_finished_indices_truncated_to_group_size_fn, in_axes=0)(just_finished_indices_within_group_size)[0]

set_group_penalty_mask_fn = lambda x, idx: x.at[idx].set(1) == 0
group_penalty = jax.vmap(set_group_penalty_mask_fn, in_axes=0)(just_finished_indices_within_group_size, not_finished_indices_truncated_to_group_size)

topk_log_probs += (add_penalty * np.array(-1.0e7) + group_penalty * np.array(-1.0e7))

# 7. Get scores, sequences, is sentence finished for next.
# Combine sequences, scores, and flags along the beam dimension and compare
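A note on the `size=`/`fill_value=` arguments used in the hunk above: they are what keep a data-dependent `jnp.where` jittable, since they make the output shape static. A minimal standalone illustration (toy values, unrelated to the PR's actual tensors):

```python
import jax.numpy as jnp

# Without size=, single-argument jnp.where returns a dynamically
# shaped index array, which cannot be traced under jit. With size=,
# the result is padded (or truncated) to a fixed length, and
# fill_value fills the padding slots.
finished = jnp.array([False, True, False, True, False])
idx = jnp.where(finished, size=4, fill_value=-1)[0]
print(idx.tolist())  # [1, 3, -1, -1]
```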