diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 7f102a80..f9bd4279 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -114,18 +114,6 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
-def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-    bs, slen, n_kv_heads, head_dim = x.shape
-    if n_rep == 1:
-        return x
-    return (
-        torch.unsqueeze(x, dim=3)
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )
-
-
 class Attention(nn.Module):
     """
     Multi-head attention module.
@@ -198,16 +186,14 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=True, enable_gqa=self.n_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
diff --git a/torchtitan/models/llama_multimodal/model.py b/torchtitan/models/llama_multimodal/model.py
index 6866f3a6..605a8db2 100644
--- a/torchtitan/models/llama_multimodal/model.py
+++ b/torchtitan/models/llama_multimodal/model.py
@@ -130,18 +130,6 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
-def repeat_kv(x: torch.Tensor, num_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=num_rep)"""
-    bsz, seq_len, num_kv_heads, head_dim = x.shape
-    if num_rep == 1:
-        return x
-    return (
-        torch.unsqueeze(x, dim=3)
-        .expand(bsz, seq_len, num_kv_heads, num_rep, head_dim)
-        .reshape(bsz, seq_len, num_kv_heads * num_rep, head_dim)
-    )
-
-
 class Attention(nn.Module):
     """
     Multi-head attention module.
@@ -222,16 +210,14 @@ def forward(
        ):  # Only used in the self attention layers for text decoder
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if num_kv_heads < n_heads
-        keys = repeat_kv(xk, self.num_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.num_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=self.is_causal)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=self.is_causal, enable_gqa=self.num_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
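
For reviewers, a minimal sketch (not part of the patch) of the equivalence this change relies on: with PyTorch 2.5+, F.scaled_dot_product_attention can broadcast the key/value heads itself via enable_gqa=True, so the manual repeat_kv expansion is redundant. The shapes and variable names below are illustrative placeholders, not taken from the model code.

import torch
import torch.nn.functional as F

# Illustrative sizes: 8 query heads sharing 2 KV heads (n_rep = 4).
bs, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
n_rep = n_heads // n_kv_heads

# Tensors already transposed to (bs, heads, seqlen, head_dim), as in forward().
xq = torch.randn(bs, n_heads, seqlen, head_dim)
xk = torch.randn(bs, n_kv_heads, seqlen, head_dim)
xv = torch.randn(bs, n_kv_heads, seqlen, head_dim)

# Old path: materialize repeated KV heads before calling SDPA
# (repeat_interleave on the head dim matches the removed repeat_kv helper).
keys = xk.repeat_interleave(n_rep, dim=1)
values = xv.repeat_interleave(n_rep, dim=1)
ref = F.scaled_dot_product_attention(xq, keys, values, is_causal=True)

# New path: let SDPA handle grouped-query attention directly (PyTorch >= 2.5).
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True, enable_gqa=True)

torch.testing.assert_close(out, ref)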