diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c4a8dc6c85..0d1807da36 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -118,6 +118,9 @@ def __init__( else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings + # Truncate the cached max sequence length to 8k to limit cached register buffer size + if config.max_position_embeddings >= 8192: + self.max_seq_len_cached = 8192 self.original_max_seq_len = config.max_position_embeddings self.config = config