diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index 0f27504780ac..91b07129dca8 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -189,7 +189,7 @@ def fused_recurrent_gated_delta_rule_fwd( B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] N = B if cu_seqlens is None else len(cu_seqlens) - 1 - BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) assert NK == 1, "NK > 1 is not supported yet" num_stages = 3