diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 12017527ac..e563ac00dc 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -263,7 +263,10 @@ def pre_attn_forward( "with a layer index." ) if token_idx is None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] else: if reuse_cache: kv_seq_len = past_key_value[0][-2]