diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1747868116..be4ba15915 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1384,13 +1384,13 @@ def greedy_search( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] + bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = None # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs["bucket_internal"] - reduce_recompile = model_kwargs["reduce_recompile"] + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] - + if not bucket_internal: if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) @@ -2167,8 +2167,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] + bucket_size = model_kwargs.get("bucket_size", -1) + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len))