Commit a47df59

add varlen attention for qwen 3

1 parent ad9f188
File tree: 1 file changed, +40 −1 lines

torchtitan/models/qwen3/model/model.py (40 additions, 1 deletion)
@@ -15,10 +15,13 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
+    VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.models.moe import MoE
 from torchtitan.protocols.model import AttentionMasksType
@@ -170,6 +173,8 @@ def __init__(self, model_args: Qwen3ModelArgs):
         match self.attn_type:
             case "flex":
                 self.inner_attention = FlexAttentionWrapper()
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
             case _:
                 self.inner_attention = ScaledDotProductAttentionWrapper()

@@ -231,6 +236,15 @@ def forward(
             case "flex":
                 assert isinstance(attention_masks, BlockMask), attention_masks
                 output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                )
             case _:
                 assert attention_masks is None
                 output = self.inner_attention(xq, xk, xv)
@@ -447,7 +461,7 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
@@ -468,6 +482,31 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attn_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer, extra_inputs
+                )
+            case "varlen":
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        f"varlen attention is only supported with block_causal \
+                        attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,

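For context on what the new "varlen" path consumes: create_varlen_metadata_for_document is not shown in this diff, but variable-length attention kernels (for example FlashAttention's varlen interface) are typically driven by cumulative sequence lengths derived from document boundaries, so attention never crosses an EOS-delimited document. The sketch below is only a rough mental model under that assumption; build_cu_seqlens and its return shape are illustrative and are not torchtitan's VarlenMetadata API.

# Illustrative sketch only (NOT torchtitan's implementation): build FlashAttention-style
# cumulative sequence lengths (cu_seqlens) from an EOS-delimited token batch, assuming
# documents end at eos_id and attention must not cross document or sample boundaries.
import torch

def build_cu_seqlens(tokens: torch.Tensor, eos_id: int) -> tuple[torch.Tensor, int]:
    """tokens: (batch, seq_len) integer tensor. Returns (cu_seqlens, max_seqlen)."""
    batch, seq_len = tokens.shape
    flat = tokens.reshape(-1)
    # Document ends = one past every EOS position in the flattened batch.
    doc_ends = (flat == eos_id).nonzero(as_tuple=True)[0] + 1
    # Every sample boundary also ends a document, even without a trailing EOS.
    sample_ends = torch.arange(1, batch + 1, device=tokens.device) * seq_len
    ends = torch.unique(torch.cat([doc_ends, sample_ends]))
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=tokens.device), ends.to(torch.int32)]
    )
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen

Presumably VarlenMetadata packages something equivalent (boundary offsets plus a maximum segment length); the diff itself only shows that the wrapper additionally receives self.head_dim alongside that metadata, and that the path is restricted to the "block_causal" mask type.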