From cd223dd2108f39607a5ade396023cced5e78feb4 Mon Sep 17 00:00:00 2001 From: Zong Wei Date: Tue, 8 Oct 2024 03:39:25 +0000 Subject: [PATCH] fix some of models text generation error use trim_logits in HPU to save memory (comment out the num_logits_to_keep in utils.py) --- optimum/habana/transformers/generation/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index a76ea59e87..bab7db4359 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1168,8 +1168,10 @@ def generate( # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: - model_kwargs["num_logits_to_keep"] = 1 + # + # Use trim_logits in HPU to save memory (in replacement of the num_logits_to_keep) + # if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: + # model_kwargs["num_logits_to_keep"] = 1 self._validate_generated_length( generation_config,