diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 9c779799c5..1484224695 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -198,9 +198,22 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): self.block_size = 4096 self.rotary_emb = GaudiRotaryEmbedding(config=self.config) + def get_k_proj_weight(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.qweight + return self.k_proj.weight + + def get_k_proj_weight_dtype(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight. + Scales tensor gets the weight dtype.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.scales.dtype + return self.k_proj.weight.dtype + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - device = self.k_proj.weight.device + device = self.get_k_proj_weight().device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) @@ -211,7 +224,7 @@ def update_sincos_cache(self, seq_len): # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -316,9 +329,11 @@ def pre_attn_forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros( + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + ) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device ) past_key_value = [past_key, past_value] key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)