diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d0e217c38d..0dae5d9cb3 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -228,9 +228,6 @@ def pre_attn_forward( else: past_key_value = None - key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) - value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) - if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht @@ -246,6 +243,9 @@ def pre_attn_forward( attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) else: + key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) + value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) + attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):