diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index a76ea59e87..bab7db4359 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1168,8 +1168,10 @@ def generate( # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: - model_kwargs["num_logits_to_keep"] = 1 + # + # On HPU, use `trim_logits` instead of `num_logits_to_keep` to save memory + # if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: + # model_kwargs["num_logits_to_keep"] = 1 self._validate_generated_length( generation_config,