@@ -607,6 +607,7 @@ def forward(
             q_descale=layer._q_scale,
             k_descale=layer._k_scale,
             v_descale=layer._v_scale,
+            s_aux=self.sinks,
         )
         return output
 
@@ -767,6 +768,7 @@ def cascade_attention(
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
+    s_aux: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.
@@ -801,6 +803,9 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        # s_aux is incorporated into prefix_lse inside the GPU kernel,
+        # enabling its effect during the final attention merge.
+        s_aux=s_aux,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
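The comment in the last hunk says the sink logits (`s_aux`) are folded into `prefix_lse` inside the GPU kernel so that they still affect the final attention merge. The sketch below illustrates that idea only; it is not the vLLM or FlashAttention implementation, and the function name `merge_with_sinks` and its signature are hypothetical. The key point is that a per-head sink acts like an extra key with logit `s_aux` and a zero value: it enlarges the prefix softmax denominator (its log-sum-exp) without adding anything to the output, and the prefix and suffix partial results are then combined weighted by their denominators.

```python
import torch


def merge_with_sinks(prefix_out: torch.Tensor,
                     prefix_lse: torch.Tensor,
                     suffix_out: torch.Tensor,
                     suffix_lse: torch.Tensor,
                     s_aux: torch.Tensor = None) -> torch.Tensor:
    """Merge two partial attention results via their log-sum-exp values.

    prefix_out, suffix_out: [num_tokens, num_heads, head_dim]
    prefix_lse, suffix_lse: [num_tokens, num_heads], log of each softmax denominator
    s_aux: optional [num_heads] sink logits (hypothetical stand-in for the
           kernel-internal handling referenced in the diff comment).
    """
    if s_aux is not None:
        # Treat the sink as one extra "key" with logit s_aux and zero value:
        # it only grows the prefix denominator, i.e. logaddexp in log space.
        prefix_lse = torch.logaddexp(prefix_lse, s_aux)

    # Numerically stable merge: weight each partial output by its
    # (rescaled) softmax denominator and renormalize.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    w_prefix = torch.exp(prefix_lse - max_lse).unsqueeze(-1)
    w_suffix = torch.exp(suffix_lse - max_lse).unsqueeze(-1)
    return (prefix_out * w_prefix + suffix_out * w_suffix) / (w_prefix + w_suffix)


if __name__ == "__main__":
    # Toy shapes just to show the call; values are random.
    T, H, D = 4, 2, 8
    out = merge_with_sinks(torch.randn(T, H, D), torch.randn(T, H),
                           torch.randn(T, H, D), torch.randn(T, H),
                           s_aux=torch.randn(H))
    print(out.shape)  # torch.Size([4, 2, 8])
```

Because the sink only changes `prefix_lse`, threading `s_aux` into the prefix call (as the diff does) is enough for it to influence the merged result; the suffix call and the merge itself need no extra arguments.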