Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchtitan/models/qwen3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch.ops.torch_attn._varlen_attn,
}


Expand Down
49 changes: 46 additions & 3 deletions torchtitan/models/qwen3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.models.attention import (
create_attention_mask,
create_varlen_metadata_for_document,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
ScaledDotProductAttentionWrapper,
VarlenAttentionWrapper,
VarlenMetadata,
)
from torchtitan.models.moe import MoE
from torchtitan.protocols.model import AttentionMasksType
Expand Down Expand Up @@ -170,8 +173,12 @@ def __init__(self, model_args: Qwen3ModelArgs):
match self.attn_type:
case "flex":
self.inner_attention = FlexAttentionWrapper()
case _:
case "varlen":
self.inner_attention = VarlenAttentionWrapper()
case "sdpa":
self.inner_attention = ScaledDotProductAttentionWrapper()
case _:
raise ValueError(f"Unknown attention type: {self.attn_type}")

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
Expand Down Expand Up @@ -231,9 +238,20 @@ def forward(
case "flex":
assert isinstance(attention_masks, BlockMask), attention_masks
output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
case _:
case "varlen":
assert isinstance(attention_masks, VarlenMetadata), attention_masks
output = self.inner_attention(
xq,
xk,
xv,
self.head_dim,
attention_masks,
)
case "sdpa":
assert attention_masks is None
output = self.inner_attention(xq, xk, xv)
case _:
raise ValueError(f"Unknown attention type: {self.attn_type}")

output = output.transpose(
1, 2
Expand Down Expand Up @@ -447,7 +465,7 @@ def _precompute_rope_cache(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def get_attention_masks(
def _get_flex_attention_masks(
self,
input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
Expand All @@ -468,6 +486,31 @@ def get_attention_masks(
and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
)

def get_attention_masks(
self,
input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
extra_inputs: dict[str, torch.Tensor] | None = None,
) -> AttentionMasksType:
match self.model_args.attn_type:
case "flex":
return self._get_flex_attention_masks(
input_batch, tokenizer, extra_inputs
)
case "varlen":
if self.model_args.attn_mask_type != "block_causal":
raise ValueError(
f"varlen attention is only supported with block_causal \
attention mask type, got {self.model_args.attn_mask_type}"
)
return create_varlen_metadata_for_document(
input_batch, tokenizer.eos_id
)
case _:
raise NotImplementedError(
"Only varlen and flex attn masks are supported"
)

def forward(
self,
tokens: torch.Tensor,
Expand Down