diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ad1521725c..f840662a30 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2872,11 +2872,16 @@ def _sample( if batch_size > 1 and has_eos_stopping_criteria: eos_token_id = generation_config.eos_token_id + # Init eos_positions + eos_positions = torch.full((batch_size,), start_token_idx, dtype=torch.long, device=input_ids.device) # Find the positions of the first eos_token_id in each sequence - eos_positions = ( - torch.isin(input_ids[:, start_token_idx:], torch.tensor(eos_token_id)).int().argmax(dim=1) - + start_token_idx - ) + eos_positions_tmp = torch.isin( + input_ids[:, start_token_idx:], torch.tensor(eos_token_id).to(device=input_ids.device) + ).int() + if eos_positions_tmp.numel() != 0: + # argmax(dim=1) is throwing this error in eager mode, if the tensor is empty + eos_positions = eos_positions + eos_positions_tmp.argmax(dim=1) + # Create a mask for positions greater than the first eos_token_id mask = torch.arange(generation_config.max_length, device="hpu").expand( batch_size, generation_config.max_length