diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d8bab89e7c..c1b14ad771 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -138,6 +138,8 @@ def _expand_dict_for_generation(dict_to_expand): def _get_hpu_graphs_kwargs(self, model_kwargs): hpu_graphs_kwargs = {} + if "limit_hpu_graphs" not in model_kwargs: + model_kwargs["limit_hpu_graphs"] = self.generation_config.limit_hpu_graphs if model_kwargs["limit_hpu_graphs"]: hpu_graphs_kwargs.update({"bypass_hpu_graphs": False}) if "first_token" not in model_kwargs.keys(): @@ -2077,6 +2079,8 @@ def beam_search( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: + if "reuse_cache" not in model_kwargs: + model_kwargs["reuse_cache"] = self.generation_config.reuse_cache if model_kwargs["reuse_cache"]: model_kwargs["past_key_values"] = unwrap_deepspeed_model(self).reorder_kv_cache(beam_idx) else: