diff --git a/python/sglang/srt/layers/attention/fla/kda.py b/python/sglang/srt/layers/attention/fla/kda.py index 4235a3fe9d10..9556fa63f216 100644 --- a/python/sglang/srt/layers/attention/fla/kda.py +++ b/python/sglang/srt/layers/attention/fla/kda.py @@ -102,8 +102,11 @@ def fused_recurrent_kda_fwd( # stride_final_state_token=stride_final_state_token, # stride_indices_seq=stride_indices_seq, # stride_indices_tok=stride_indices_tok, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, IS_BETA_HEADWISE=beta.ndim == v.ndim, USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + IS_VARLEN=cu_seqlens is not None, # INPLACE_FINAL_STATE=inplace_final_state, IS_KDA=True, num_warps=num_warps,