From 072f785eccd4aedf40a7342e2e4ae7bb2ef32755 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Mon, 26 Jun 2023 23:13:51 +0000 Subject: [PATCH] Add flax diverse group search --- src/transformers/generation/flax_utils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index f18cc0ea84ff..efe0948d9aaa 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -751,6 +751,7 @@ def _beam_search( trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, num_return_sequences: Optional[int] = None, + num_beam_groups: Optional[int] = 1, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): """ @@ -801,6 +802,8 @@ def gather_fn(tensor): batch_size, num_beams, cur_len = input_ids.shape + group_size = num_beams // num_beam_groups + eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) cur_len = jnp.array(cur_len) @@ -954,7 +957,20 @@ def beam_search_body_fn(state, input_ids_length=1): state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape ) & (early_stopping is True) add_penalty = ~did_topk_just_finished | beams_in_batch_are_full - topk_log_probs += add_penalty * np.array(-1.0e7) + + # Add additional logic for diverse beam search + get_finished_indices_fn = lambda x: jax.numpy.where(x, size=beams_to_keep, fill_value=-1) + just_finished_indices = jax.vmap(get_finished_indices_fn, in_axes=0)(did_topk_just_finished)[0] + + just_finished_indices_within_group_size = ((just_finished_indices < group_size) & (just_finished_indices >= 0)).astype(jnp.int32) + + get_not_finished_indices_truncated_to_group_size_fn = lambda x: jax.numpy.where(x==0, size=group_size) + not_finished_indices_truncated_to_group_size = jax.vmap(get_not_finished_indices_truncated_to_group_size_fn, in_axes=0)(just_finished_indices_within_group_size)[0] + + set_group_penalty_mask_fn = lambda x, idx: x.at[idx].set(1) == 0 + group_penalty = jax.vmap(set_group_penalty_mask_fn, in_axes=0)(just_finished_indices_within_group_size, not_finished_indices_truncated_to_group_size) + + topk_log_probs += (add_penalty * np.array(-1.0e7) + group_penalty * np.array(-1.0e7)) # 7. Get scores, sequences, is sentence finished for next. # Combine sequences, scores, and flags along the beam dimension and compare