From 97d7637220806ce670a279065bcdc221f37b8661 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 30 Jan 2024 17:28:23 +0000 Subject: [PATCH] Fix tests --- optimum/habana/transformers/generation/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bc8fad5118..20cf008548 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1368,8 +1368,8 @@ def greedy_search( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] + bucket_size = model_kwargs.get("bucket_size", -1) + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] if bucket_size >= 0: @@ -2121,8 +2121,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] + bucket_size = model_kwargs.get("bucket_size", -1) + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len))