From 8c0ead3e6dc2983b74ef7f24d2808426785fbfdf Mon Sep 17 00:00:00 2001 From: Jianhong-Zhang Date: Tue, 24 Sep 2024 17:43:19 -0700 Subject: [PATCH] Fix GPT_neox incorrect output with batch query --- optimum/habana/transformers/generation/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 7315ff0bff..ec22cb00b3 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2437,6 +2437,17 @@ def _sample( if streamer is not None: streamer.end() + if batch_size > 1 and has_eos_stopping_criteria: + eos_token_id = generation_config.eos_token_id + idx_bs = generation_config.max_length + for i in range(batch_size): + for idx in range(len(input_ids[i])): + if input_ids[i][idx] == eos_token_id: + idx_bs = idx + if idx > idx_bs: + input_ids[i][idx] = pad_token_id + idx_bs = generation_config.max_length + if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput(