diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index f0e9935c81..ce7d3cc283 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -119,7 +119,7 @@ def __init__( self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings # Truncate the cached max sequence length to 8k to limit cached register buffer size - if config.max_position_embeddings >= 8192: + if config.max_position_embeddings > 8192 and self.rope_type == "llama3": self.max_seq_len_cached = 8192 self.original_max_seq_len = config.max_position_embeddings