From e20ab9f232f891d2c842398c830ecf0de917f61e Mon Sep 17 00:00:00 2001 From: Danny Date: Tue, 25 Jun 2024 10:19:41 +0300 Subject: [PATCH] Fixed self.k_proj.weight when using gptq --- .../transformers/models/llama/modeling_llama.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 42cf4bef39..f0707df8b0 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -300,6 +300,13 @@ def get_k_proj_weight(self): return self.k_proj.qweight return self.k_proj.weight + def get_k_proj_weight_dtype(self): + """ 4bit quantization in GPTQ replaces the k_proj.weight with qweight. + Scales tensor gets the weight dtype. """ + if hasattr(self.k_proj, 'qweight'): + return self.k_proj.scales.dtype + return self.k_proj.weight.dtype + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.get_k_proj_weight().device @@ -418,9 +425,9 @@ def pre_attn_forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros(key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device ) # Return list instead of tuple past_key_value = [past_key, past_value]