diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 1bdf58691a80..3a89c1ed41d2 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -911,7 +911,7 @@ def beam_search_body_fn(state, input_ids_length=1): # add new logprobs to existing running logprobs scores. log_probs = jax.nn.log_softmax(logits) log_probs = logits_processor( - flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len + flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len ) log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)