 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
+    VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.models.moe import MoE
 from torchtitan.protocols.model import AttentionMasksType
@@ -170,6 +173,8 @@ def __init__(self, model_args: Qwen3ModelArgs):
         match self.attn_type:
             case "flex":
                 self.inner_attention = FlexAttentionWrapper()
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
             case _:
                 self.inner_attention = ScaledDotProductAttentionWrapper()

@@ -231,6 +236,15 @@ def forward(
             case "flex":
                 assert isinstance(attention_masks, BlockMask), attention_masks
                 output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                )
             case _:
                 assert attention_masks is None
                 output = self.inner_attention(xq, xk, xv)
@@ -447,7 +461,7 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
@@ -468,6 +482,31 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attn_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer, extra_inputs
+                )
+            case "varlen":
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        "varlen attention is only supported with block_causal "
+                        f"attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
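For context, here is a minimal, hypothetical sketch of the kind of metadata the new varlen path consumes: with document packing, cumulative sequence lengths (cu_seqlens) are derived from EOS token positions so that attention for each packed document stays within its own boundaries. The helper name build_cu_seqlens is made up for illustration and is not torchtitan's create_varlen_metadata_for_document.

import torch

def build_cu_seqlens(input_batch: torch.Tensor, eos_id: int) -> torch.Tensor:
    # Hypothetical helper, for illustration only: derive cumulative sequence
    # lengths for packed documents by treating each EOS token as a boundary.
    flat = input_batch.reshape(-1)  # flatten (batch, seq) into one packed stream
    boundaries = (flat == eos_id).nonzero(as_tuple=True)[0] + 1
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int64, device=flat.device), boundaries]
    )
    # Close the final segment if the batch does not end on an EOS token.
    if cu_seqlens[-1] != flat.numel():
        cu_seqlens = torch.cat(
            [cu_seqlens, torch.tensor([flat.numel()], device=flat.device)]
        )
    return cu_seqlens.to(torch.int32)

# Two packed documents of lengths 3 and 2, plus a trailing partial document.
tokens = torch.tensor([[5, 7, 0, 9, 0, 4]])
print(build_cu_seqlens(tokens, eos_id=0))  # tensor([0, 3, 5, 6], dtype=torch.int32)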