diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 42a3a6f4cb..82fcf6f0c3 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -76,19 +76,38 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -def gaudi_llama_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: +def gaudi_llama_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): """ Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape + batch, num_key_value_heads, kv_len, head_dim = key_states.shape if n_rep == 1 or num_key_value_heads == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask class Matmul(torch.nn.Module): @@ -251,21 +270,27 @@ def pre_attn_forward( ) else: - key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) - value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + if attn_weights.size() not in [ + (bsz, self.num_heads, q_len, kv_seq_len), + (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), + ]: raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" + f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask @@ -278,6 +303,7 @@ def pre_attn_forward( ) attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError(