diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index f2fd06cbce..084a66d08e 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -144,7 +144,16 @@ def _test_prefill_kernel( @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( "num_q_heads, num_k_heads, num_v_heads", - [(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)], + [ + (1, 1, 1), + (4, 1, 1), + (3, 3, 3), + (6, 2, 2), + (1, 1, 2), + (2, 2, 4), + (16, 16, 32), + (16, 16, 64), + ], ) @pytest.mark.parametrize("seq_lens", [[64], [128], [256], [256, 256], [64, 128, 512]]) @pytest.mark.parametrize("block_size", [64]) @@ -186,7 +195,16 @@ def test_prefill_kernel_basic( @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( "num_q_heads, num_k_heads, num_v_heads", - [(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)], + [ + (1, 1, 1), + (4, 1, 1), + (3, 3, 3), + (6, 2, 2), + (1, 1, 2), + (2, 2, 4), + (16, 16, 32), + (16, 16, 64), + ], ) @pytest.mark.parametrize( "seq_lens", @@ -390,7 +408,8 @@ def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2): @pytest.mark.parametrize("scale", [1.0, "auto"]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( - "num_q_heads, num_k_heads, num_v_heads", [(6, 2, 2), (2, 2, 4)] + "num_q_heads, num_k_heads, num_v_heads", + [(6, 2, 2), (2, 2, 4), (16, 16, 32), (16, 16, 64)], ) @pytest.mark.parametrize( "seq_lens1, seq_lens2",