diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0c32418eaa..8071985a6f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -102,7 +102,7 @@ def __init__(self, config: LlamaConfig): self.past_key = None self.past_value = None self.inp_seq_len = -1 - self.register_buffer("norm_factor", torch.tensor(1.0 / math.sqrt(self.head_dim)), persistent=False) + self.norm_factor = 1.0 / math.sqrt(self.head_dim) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): key_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)