diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 2f65d26fdc86..440a0ae26275 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -448,10 +448,11 @@ def generate( model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams ) - if "attention_mask" in model_kwargs: - model_kwargs["attention_mask"] = self._expand_to_num_beams( - model_kwargs["attention_mask"], num_beams=generation_config.num_beams - ) + for kwarg in ["attention_mask", "decoder_attention_mask"]: + if kwarg in model_kwargs: + model_kwargs[kwarg] = self._expand_to_num_beams( + model_kwargs[kwarg], num_beams=generation_config.num_beams + ) return self._beam_search( input_ids, @@ -821,8 +822,9 @@ def gather_fn(tensor): model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( model_kwargs["encoder_outputs"]["last_hidden_state"] ) - if "attention_mask" in model_kwargs: - model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) + for kwarg in ["attention_mask", "decoder_attention_mask"]: + if kwarg in model_kwargs: + model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg]) # initialize model specific kwargs model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)