torchtitan/experiments/llama4/model/model.py (65 changes: 56 additions & 9 deletions)
@@ -4,14 +4,22 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.models.attention import (
create_attention_mask,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
get_fixed_block_mask_mod,
ScaledDotProductAttentionWrapper,
)
from torchtitan.models.moe import MoE
from torchtitan.protocols import ModelProtocol
from torchtitan.protocols.model import AttentionMasksType
from torchtitan.protocols.train_spec import ModelProtocol

from .args import TransformerModelArgs

@@ -155,9 +163,11 @@ def __init__(
# values of these two variables.
self.use_rope = use_rope

self.sdpa = build_attention(
model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
)
self.use_flex_attn = model_args.use_flex_attn
if self.use_flex_attn:
self.inner_attention = FlexAttentionWrapper()
else:
self.inner_attention = ScaledDotProductAttentionWrapper()

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
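The old `build_attention` factory is gone; each attention layer now picks its wrapper explicitly. As a rough sketch of the call-site contract this implies (an assumption drawn from the call sites below, not the actual contents of `torchtitan/models/attention.py`): `FlexAttentionWrapper` takes a precomputed `block_mask`, while `ScaledDotProductAttentionWrapper` keeps a plain causal SDPA path.

# Sketch only, assuming the wrappers are thin shims over flex_attention / SDPA.
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention

class FlexAttentionWrapperSketch(torch.nn.Module):
    def forward(self, q, k, v, block_mask: BlockMask):
        # Consumes the BlockMask built once per batch in get_attention_masks().
        return flex_attention(q, k, v, block_mask=block_mask)

class ScaledDotProductAttentionWrapperSketch(torch.nn.Module):
    def forward(self, q, k, v):
        # No mask argument: plain causal SDPA.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)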
@@ -168,6 +178,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
Forward pass of the attention module.
@@ -202,7 +213,13 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

output = self.sdpa(xq, xk, xv)
if self.use_flex_attn:
assert isinstance(attention_masks, dict), attention_masks
attention_mask = attention_masks["rope" if self.use_rope else "nope"]
output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
else:
assert attention_masks is None
output = self.inner_attention(xq, xk, xv)

output = output.transpose(
1, 2
@@ -335,6 +352,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
Perform a forward pass through the TransformerBlock.
@@ -347,7 +365,7 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
h = x + self.attention(self.attention_norm(x), freqs_cis)
h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
if self.moe_enabled:
out = h + self.moe(self.ffn_norm(h))
else:
@@ -447,9 +465,38 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def get_attention_masks(
self,
input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
extra_inputs: dict[str, torch.Tensor] | None = None,
) -> AttentionMasksType:
nope_mask_mod = get_causal_mask_mod()
match self.model_args.attn_mask_type:
case "causal":
B = 1
case "block_causal":
B = input_batch.shape[0]
nope_mask_mod = get_document_mask_mod(
nope_mask_mod, input_batch, tokenizer.eos_id
)
case _:
raise ValueError(f"Unknown attention mask type: {self.attn_mask_type}")

rope_mask_mod = get_fixed_block_mask_mod(
nope_mask_mod, self.model_args.fixed_attn_block_size
)

seqlen = input_batch.shape[1]
return {
"rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen),
"nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen),
}
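
For readers new to FlexAttention: a mask_mod is a predicate over (batch, head, q_idx, kv_idx), and `create_attention_mask` presumably forwards its (mask_mod, B, H, Q_LEN, KV_LEN) arguments to `torch.nn.attention.flex_attention.create_block_mask`. A minimal sketch of the two masks assembled above, under that assumption and with illustrative sizes (the real chunk size comes from `fixed_attn_block_size`):

# Sketch only: approximates the "nope"/"rope" masks with raw FlexAttention APIs.
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    # "nope" layers: causal attention (optionally ANDed with a document mask).
    return q_idx >= kv_idx

def fixed_block_mask_mod(inner_mod, block_size):
    # "rope" layers: causal attention confined to fixed-size chunks,
    # mirroring what get_fixed_block_mask_mod composes.
    def mask_mod(b, h, q_idx, kv_idx):
        same_block = (q_idx // block_size) == (kv_idx // block_size)
        return same_block & inner_mod(b, h, q_idx, kv_idx)
    return mask_mod

seqlen, block_size = 8192, 2048  # illustrative values
nope_mask = create_block_mask(causal_mask_mod, None, None, seqlen, seqlen)
rope_mask = create_block_mask(
    fixed_block_mask_mod(causal_mask_mod, block_size), None, None, seqlen, seqlen
)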

def forward(
self,
tokens: torch.Tensor,
attention_masks: AttentionMasksType | None,
input_batch: torch.Tensor | None = None,
):
"""
@@ -473,7 +520,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = layer(h, self.freqs_cis, attention_masks)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
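
With these changes the calling convention shifts: masks are built once per batch from the tokenizer's EOS id and threaded through every layer. A minimal trainer-side sketch, assuming `model`, `tokenizer` (a `BaseTokenizer`), and a `(batch, seqlen)` token tensor `input_batch` already exist; when `use_flex_attn` is off, `attention_masks` simply stays `None`:

# Sketch only: illustrative call sequence, not the actual torchtitan training loop.
if model.model_args.use_flex_attn:
    # Returns the {"rope": ..., "nope": ...} BlockMasks built by get_attention_masks().
    attention_masks = model.get_attention_masks(input_batch=input_batch, tokenizer=tokenizer)
else:
    attention_masks = None

logits = model(input_batch, attention_masks=attention_masks)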