 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
+    VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.models.moe import MoE
 from torchtitan.protocols.model import AttentionMasksType
@@ -170,8 +173,12 @@ def __init__(self, model_args: Qwen3ModelArgs):
         match self.attn_type:
             case "flex":
                 self.inner_attention = FlexAttentionWrapper()
-            case _:
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
+            case "sdpa":
                 self.inner_attention = ScaledDotProductAttentionWrapper()
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -231,9 +238,20 @@ def forward(
             case "flex":
                 assert isinstance(attention_masks, BlockMask), attention_masks
                 output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-            case _:
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                )
+            case "sdpa":
                 assert attention_masks is None
                 output = self.inner_attention(xq, xk, xv)
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")

         output = output.transpose(
             1, 2
@@ -447,7 +465,7 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
@@ -468,6 +486,31 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attn_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer, extra_inputs
+                )
+            case "varlen":
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        "varlen attention is only supported with block_causal "
+                        f"attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
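
For readers unfamiliar with the varlen path this diff wires up: instead of materializing a (seq_len × seq_len) block-causal mask, variable-length attention packs eos-delimited documents into one flat token stream and hands the kernel cumulative sequence lengths, so each document only attends to itself. The sketch below shows one way such metadata could be derived with plain PyTorch; the names (`build_varlen_info`, `VarlenInfo`) and fields are hypothetical illustrations, not torchtitan's actual `create_varlen_metadata_for_document` / `VarlenMetadata` implementation.

```python
# Illustrative sketch only: derives cumulative document lengths for a
# varlen-style attention call from an eos-delimited token batch.
# Names here are hypothetical, not torchtitan APIs.
from dataclasses import dataclass

import torch


@dataclass
class VarlenInfo:
    cu_seqlens: torch.Tensor  # (num_docs + 1,) int32 prefix sums of document lengths
    max_seqlen: int           # length of the longest packed document


def build_varlen_info(tokens: torch.Tensor, eos_id: int) -> VarlenInfo:
    """tokens: (batch, seq_len) of token ids; documents are separated by eos_id."""
    batch, seq_len = tokens.shape
    flat = tokens.reshape(-1)

    # A document ends after every eos token and at the end of every row.
    eos_ends = (flat == eos_id).nonzero(as_tuple=True)[0] + 1
    row_ends = torch.arange(1, batch + 1, device=tokens.device) * seq_len
    boundaries = torch.unique(torch.cat([eos_ends, row_ends]))  # sorted, deduplicated

    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=tokens.device),
            boundaries.to(torch.int32),
        ]
    )
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return VarlenInfo(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
```

Cumulative lengths in this form are what varlen kernels (for example FlashAttention's varlen entry points) consume in place of a dense mask, which is the main memory advantage over the flex/block-causal mask path.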