Commit eb577e4

[Bugfix] Add missing sink tensor into flash attn cascade attn implementation (#26325)
1 parent 8f36850 commit eb577e4

File tree: 1 file changed (+5, −0 lines)


vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 0 deletions
@@ -607,6 +607,7 @@ def forward(
             q_descale=layer._q_scale,
             k_descale=layer._k_scale,
             v_descale=layer._v_scale,
+            s_aux=self.sinks,
         )
         return output
 
@@ -767,6 +768,7 @@ def cascade_attention(
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
+    s_aux: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.
@@ -801,6 +803,9 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        # s_aux is incorporated into prefix_lse inside the GPU kernel,
+        # enabling its effect during the final attention merge.
+        s_aux=s_aux,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
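
For context on the inline comment in the last hunk: an attention sink contributes extra mass to the softmax denominator but no value vector, so folding its logit into the prefix log-sum-exp lets the standard LSE-weighted merge of prefix and suffix outputs remain exact. The sketch below illustrates that identity in plain PyTorch; `attend`, `merge`, and all shapes are hypothetical names for illustration only, not the vLLM or FlashAttention kernel API.

```python
# Minimal numerical sketch (assumed semantics, not the vLLM/FlashAttention kernel):
# a sink adds softmax mass but no value, and folding it into the prefix LSE keeps
# the usual two-way LSE merge exact.
from typing import Optional

import torch


def attend(scores: torch.Tensor, v: torch.Tensor,
           sink: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
    # scores: [T_q, T_k], v: [T_k, D], sink: scalar logit (per-head in practice).
    lse = torch.logsumexp(scores, dim=-1)                # [T_q]
    if sink is not None:
        lse = torch.logaddexp(lse, sink)                 # fold the sink into the LSE
    out = torch.exp(scores - lse.unsqueeze(-1)) @ v      # denominator includes the sink
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    # Standard LSE-weighted merge of two partial attention results.
    lse = torch.logaddexp(lse_a, lse_b)
    return (out_a * torch.exp(lse_a - lse).unsqueeze(-1)
            + out_b * torch.exp(lse_b - lse).unsqueeze(-1))


torch.manual_seed(0)
q, k, v = torch.randn(2, 8), torch.randn(6, 8), torch.randn(6, 8)
sink = torch.tensor(0.5)

# Reference: one pass over all keys with the sink in the denominator.
full, _ = attend(q @ k.T, v, sink)

# Cascade: prefix keys (sink folded into the prefix LSE) merged with suffix keys.
out_p, lse_p = attend(q @ k[:4].T, v[:4], sink)
out_s, lse_s = attend(q @ k[4:].T, v[4:])
assert torch.allclose(merge(out_p, lse_p, out_s, lse_s), full, atol=1e-6)
```

Consistent with the diff's comment, the sink is folded in once, on the prefix pass, so it is not double-counted when the partial results are merged.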
