Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2612,15 +2612,24 @@ def _sample(
streamer.end()

if batch_size > 1 and has_eos_stopping_criteria:
# cover the over-generated tokens after eos_token with pad_token
eos_token_id = generation_config.eos_token_id
idx_bs = generation_config.max_length
def find_first_eos_token_idx_in_input_ids(batch_id) -> int:
    """Return the index of the first EOS token in row ``batch_id`` of ``input_ids``.

    Reads ``input_ids``, ``pad_token_id`` and ``eos_token_id`` from the
    enclosing scope. Leading entries equal to ``pad_token_id`` are skipped
    before the search starts, so a pad id that coincides with an EOS id is
    not mistaken for the end of the sequence.

    Args:
        batch_id: Index of the row of ``input_ids`` to scan.

    Returns:
        The position of the first EOS token after the leading padding, or
        ``len(input_ids[batch_id])`` when the row contains no EOS token or
        when ``eos_token_id`` is ``None``.
    """
    row = input_ids[batch_id]
    row_length = len(row)

    # Without a configured EOS id there is nothing to find. Report "not
    # found" rather than the first non-pad position, which would make the
    # caller overwrite the whole sequence with padding.
    if eos_token_id is None:
        return row_length

    # Accept a single id as well as any iterable of ids (list, tuple, set, ...)
    # so that every supported spelling goes through the same scan below.
    try:
        eos_ids = tuple(eos_token_id)
    except TypeError:
        eos_ids = (eos_token_id,)

    idx = 0
    # Skip the leading padding.
    while idx < row_length and row[idx] == pad_token_id:
        idx += 1
    # Advance to the first EOS token, or to the end of the row.
    while idx < row_length and row[idx] not in eos_ids:
        idx += 1
    return idx
for i in range(batch_size):
for idx in range(len(input_ids[i])):
if input_ids[i][idx] == eos_token_id:
idx_bs = idx
if idx > idx_bs:
input_ids[i][idx] = pad_token_id
idx_bs = generation_config.max_length
eos_idx = find_first_eos_token_idx_in_input_ids(i)
for j in range(eos_idx+1, len(input_ids[i])) :
input_ids[i][j] = pad_token_id

if return_dict_in_generate:
if self.config.is_encoder_decoder:
Expand Down