Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,11 @@ def generate(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
)

if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=generation_config.num_beams
)
for kwarg in ["attention_mask", "decoder_attention_mask"]:
if kwarg in model_kwargs:
model_kwargs[kwarg] = self._expand_to_num_beams(
model_kwargs[kwarg], num_beams=generation_config.num_beams
)

return self._beam_search(
input_ids,
Expand Down Expand Up @@ -821,8 +822,9 @@ def gather_fn(tensor):
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
model_kwargs["encoder_outputs"]["last_hidden_state"]
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
for kwarg in ["attention_mask", "decoder_attention_mask"]:
if kwarg in model_kwargs:
model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg])

# initialize model specific kwargs
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
Expand Down