
Commit 1b9cfda: add varlen attention for qwen 3 (#2084)

As title.

**Testing**
[Screenshot 2025-11-24 at 4 30 53 PM: https://github.com/user-attachments/assets/6b9a362d-de36-48b7-b465-d91ae24f4cbf]
Performance and loss are on par.
1 parent d0393b3 commit 1b9cfda

File tree: 2 files changed (+47, -3 lines)


torchtitan/models/qwen3/infra/parallelize.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -43,6 +43,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }
 
 
```
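For context, the op set extended above appears to be the save-list that torchtitan feeds to PyTorch's selective activation checkpointing (an assumption based on the surrounding comment; the rest of the file is not shown in this diff). A minimal sketch of how such a save-list is typically consumed, not torchtitan's actual wiring:

```python
# Minimal sketch, assuming the set above is a selective-activation-checkpointing
# (SAC) save-list. Ops in the list have their outputs saved; others are recomputed.
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

_op_sac_save_list = {
    torch.ops.aten.max.default,
    torch._higher_order_ops.flex_attention,
    # torch.ops.torch_attn._varlen_attn,  # the newly allow-listed varlen attention op
}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of allow-listed ops; let everything else be recomputed in backward.
    if op in _op_sac_save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

# Passed as context_fn to torch.utils.checkpoint.checkpoint(..., use_reentrant=False).
context_fn = lambda: create_selective_checkpoint_contexts(_policy)
```

Adding the varlen op to the list keeps its output from being recomputed in backward, mirroring how flex attention is already treated.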
torchtitan/models/qwen3/model/model.py
Lines changed: 46 additions & 3 deletions

```diff
@@ -15,10 +15,13 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
+    VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.models.moe import MoE
 from torchtitan.protocols.model import AttentionMasksType
@@ -170,8 +173,12 @@ def __init__(self, model_args: Qwen3ModelArgs):
         match self.attn_type:
             case "flex":
                 self.inner_attention = FlexAttentionWrapper()
-            case _:
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
+            case "sdpa":
                 self.inner_attention = ScaledDotProductAttentionWrapper()
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -231,9 +238,20 @@ def forward(
             case "flex":
                 assert isinstance(attention_masks, BlockMask), attention_masks
                 output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-            case _:
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                )
+            case "sdpa":
                 assert attention_masks is None
                 output = self.inner_attention(xq, xk, xv)
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")
 
         output = output.transpose(
             1, 2
```
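To make the new `varlen` branch concrete: with document packing and a block-causal mask, variable-length attention over a packed row is equivalent to running causal attention independently on each document segment. The sketch below shows that reference semantics with plain SDPA; it illustrates what a fused varlen kernel computes and is not torchtitan's `VarlenAttentionWrapper` implementation. The `(total_tokens, n_heads, head_dim)` layout and the `cu_seqlens` boundary convention are assumptions borrowed from common varlen kernels.

```python
# Reference-semantics sketch only: per-document causal SDPA over a packed row.
# Layout assumption: q, k, v are (total_tokens, n_heads, head_dim) and
# cu_seqlens holds cumulative document boundaries, e.g. [0, 4, 8].
import torch
import torch.nn.functional as F

def varlen_block_causal_reference(q, k, v, cu_seqlens):
    out = torch.empty_like(q)
    starts, ends = cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()
    for start, end in zip(starts, ends):
        # (seg_len, n_heads, head_dim) -> (n_heads, seg_len, head_dim) for SDPA.
        qs, ks, vs = (t[start:end].transpose(0, 1) for t in (q, k, v))
        seg = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
        out[start:end] = seg.transpose(0, 1)
    return out
```

The fused kernel avoids this Python loop and the padding a dense mask would need, which is where the on-par performance with no loss regression comes from.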
```diff
@@ -447,7 +465,7 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
@@ -468,6 +486,31 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attn_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer, extra_inputs
+                )
+            case "varlen":
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        f"varlen attention is only supported with block_causal \
+                        attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
```
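For intuition on what the `varlen` branch of `get_attention_masks` has to produce: `create_varlen_metadata_for_document` is not shown in this diff, but with `block_causal` packing the essential metadata is the set of cumulative sequence lengths marking document boundaries, which can be recovered from EOS positions. A conceptual sketch, not torchtitan's implementation; the helper name `sketch_cu_seqlens` and the single-packed-row simplification are illustrative only.

```python
# Conceptual sketch: derive cu_seqlens-style boundaries from EOS token positions
# in a packed batch. Assumes a single packed row for simplicity; this is not
# what create_varlen_metadata_for_document actually returns.
import torch

def sketch_cu_seqlens(input_batch: torch.Tensor, eos_id: int) -> torch.Tensor:
    tokens = input_batch[0]  # (seq_len,) token ids of one packed row
    eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0]
    # Boundaries: start of the row, one past each EOS, and end of the row.
    boundaries = torch.cat(
        [
            torch.zeros(1, dtype=torch.long, device=tokens.device),
            eos_positions + 1,
            torch.tensor([tokens.numel()], device=tokens.device),
        ]
    )
    # Drop the duplicate boundary when the row already ends with an EOS token.
    return torch.unique_consecutive(boundaries).to(torch.int32)

# Two documents of length 4 packed into one row of length 8 (eos_id == 2):
packed = torch.tensor([[5, 6, 7, 2, 9, 9, 9, 2]])
print(sketch_cu_seqlens(packed, eos_id=2))  # tensor([0, 4, 8], dtype=torch.int32)
```

This is also why the diff rejects any `attn_mask_type` other than `block_causal`: without per-document boundaries there is nothing for the varlen path to segment on.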
