From 1efb76805bfab5f1c94c646209093a8727f6b99a Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Thu, 17 Oct 2024 10:33:37 -0700 Subject: [PATCH] Fix OOM error for code llama --- optimum/habana/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index f0e9935c81..ce7d3cc283 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -119,7 +119,7 @@ def __init__( self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings # Truncate the cached max sequence length to 8k to limit cached register buffer size - if config.max_position_embeddings >= 8192: + if config.max_position_embeddings > 8192 and self.rope_type == "llama3": self.max_seq_len_cached = 8192 self.original_max_seq_len = config.max_position_embeddings