diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 36dd62e743..86cd3ed3da 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2430,6 +2430,7 @@ def _contrastive_search( do_padding = ( key_to_check is not None and outputs.past_key_values[0][0].shape[2] == model_inputs[key_to_check].shape[1] + and generation_config.max_new_tokens > 1 ) if do_padding: @@ -2843,6 +2844,7 @@ def _sample( do_padding = ( key_to_check is not None and outputs.past_key_values[0][0].shape[2] == model_inputs[key_to_check].shape[1] + and generation_config.max_new_tokens > 1 ) if do_padding: