Skip to content

Commit 74d153f

Browse files
committed
reduce the number of unit tests
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 997b913 commit 74d153f

File tree

1 file changed

+77
-41
lines changed

1 file changed

+77
-41
lines changed

tests/attention/test_trtllm_gen_attention.py

Lines changed: 77 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -339,40 +339,7 @@ def unpack_compare_nvfp4(
339339
return output_unpacked, output_ref
340340

341341

342-
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
343-
@pytest.mark.parametrize(
344-
"batch_size,page_size,num_kv_heads,head_grp_size",
345-
[
346-
(4, 16, 2, 1),
347-
(4, 32, 4, 5),
348-
(4, 64, 4, 8),
349-
(128, 16, 2, 5),
350-
(128, 32, 4, 1),
351-
(128, 64, 2, 8),
352-
(256, 16, 4, 8),
353-
(256, 32, 2, 8),
354-
(256, 64, 4, 1),
355-
(256, 64, 4, 5),
356-
],
357-
)
358-
@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left
359-
@pytest.mark.parametrize(
360-
"q_dtype,kv_dtype,o_dtype",
361-
[
362-
("bf16", "bf16", "bf16"),
363-
("fp16", "fp16", "fp16"),
364-
("fp8", "fp8", "bf16"),
365-
("fp8", "fp8", "fp16"),
366-
("fp8", "fp8", "fp8"),
367-
("fp8", "fp8", "nvfp4"),
368-
],
369-
)
370-
@pytest.mark.parametrize("enable_pdl", [True, False, None])
371-
@pytest.mark.parametrize("enable_sink", [True, False])
372-
@pytest.mark.parametrize("max_q_len", [511])
373-
@pytest.mark.parametrize("max_kv_len", [2047])
374-
@pytest.mark.parametrize("device_scale", [True, False])
375-
def test_trtllm_batch_prefill(
342+
def _test_trtllm_batch_prefill(
376343
kv_layout,
377344
batch_size,
378345
page_size,
@@ -580,6 +547,71 @@ def test_trtllm_batch_prefill(
580547
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
581548

582549

550+
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
551+
@pytest.mark.parametrize(
552+
"batch_size,page_size,num_kv_heads,head_grp_size",
553+
[
554+
(4, 16, 2, 1),
555+
(4, 32, 4, 5),
556+
(4, 64, 4, 8),
557+
(128, 16, 2, 5),
558+
(128, 32, 4, 1),
559+
(128, 64, 2, 8),
560+
(256, 16, 4, 8),
561+
(256, 32, 2, 8),
562+
(256, 64, 4, 1),
563+
(256, 64, 4, 5),
564+
],
565+
)
566+
@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left
567+
@pytest.mark.parametrize(
568+
"q_dtype,kv_dtype,o_dtype",
569+
[
570+
("bf16", "bf16", "bf16"),
571+
("fp16", "fp16", "fp16"),
572+
("fp8", "fp8", "bf16"),
573+
("fp8", "fp8", "fp16"),
574+
("fp8", "fp8", "fp8"),
575+
("fp8", "fp8", "nvfp4"),
576+
],
577+
)
578+
@pytest.mark.parametrize("enable_pdl", [None])
579+
@pytest.mark.parametrize("enable_sink", [True, False])
580+
@pytest.mark.parametrize("max_q_len", [511])
581+
@pytest.mark.parametrize("max_kv_len", [2047])
582+
def test_trtllm_batch_prefill(
583+
kv_layout,
584+
batch_size,
585+
page_size,
586+
num_kv_heads,
587+
head_grp_size,
588+
window_left,
589+
q_dtype,
590+
o_dtype,
591+
kv_dtype,
592+
enable_pdl,
593+
enable_sink,
594+
max_q_len,
595+
max_kv_len,
596+
):
597+
_test_trtllm_batch_prefill(
598+
kv_layout,
599+
batch_size,
600+
page_size,
601+
num_kv_heads,
602+
head_grp_size,
603+
window_left,
604+
q_dtype,
605+
o_dtype,
606+
kv_dtype,
607+
enable_pdl,
608+
enable_sink,
609+
max_q_len,
610+
max_kv_len,
611+
kv_dtype == "fp8",
612+
)
613+
614+
583615
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
584616
@pytest.mark.parametrize(
585617
"batch_size,page_size,num_kv_heads,head_grp_size",
@@ -613,7 +645,7 @@ def test_trtllm_batch_prefill_bs1(
613645
max_q_len,
614646
max_kv_len,
615647
):
616-
test_trtllm_batch_prefill(
648+
_test_trtllm_batch_prefill(
617649
kv_layout,
618650
batch_size,
619651
page_size,
@@ -966,7 +998,6 @@ def _test_trtllm_batch_decode(
966998
@pytest.mark.parametrize("enable_sink", [True, False])
967999
@pytest.mark.parametrize("max_in_kv_len", [110])
9681000
@pytest.mark.parametrize("head_dim", [128])
969-
@pytest.mark.parametrize("device_scale", [True, False])
9701001
def test_trtllm_batch_decode(
9711002
backend,
9721003
kv_layout,
@@ -983,7 +1014,6 @@ def test_trtllm_batch_decode(
9831014
enable_sink,
9841015
max_in_kv_len,
9851016
head_dim,
986-
device_scale,
9871017
):
9881018
# General set of tests for trtllm-gen decode
9891019
_test_trtllm_batch_decode(
@@ -1002,7 +1032,7 @@ def test_trtllm_batch_decode(
10021032
enable_sink,
10031033
max_in_kv_len,
10041034
head_dim,
1005-
device_scale,
1035+
kv_dtype == "fp8",
10061036
)
10071037

10081038

@@ -1024,6 +1054,7 @@ def test_trtllm_batch_decode(
10241054
@pytest.mark.parametrize("enable_sink", [False])
10251055
@pytest.mark.parametrize("max_in_kv_len", [8192])
10261056
@pytest.mark.parametrize("head_dim", [128])
1057+
@pytest.mark.parametrize("device_scale", [True, False])
10271058
def test_trtllm_batch_decode_bs1(
10281059
kv_layout,
10291060
batch_size,
@@ -1039,6 +1070,7 @@ def test_trtllm_batch_decode_bs1(
10391070
enable_sink,
10401071
max_in_kv_len,
10411072
head_dim,
1073+
device_scale,
10421074
):
10431075
# Small number of test cases for batch size 1
10441076
pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
@@ -1058,7 +1090,7 @@ def test_trtllm_batch_decode_bs1(
10581090
enable_sink,
10591091
max_in_kv_len,
10601092
head_dim,
1061-
False,
1093+
device_scale,
10621094
)
10631095

10641096

@@ -1091,6 +1123,7 @@ def test_trtllm_batch_decode_bs1(
10911123
@pytest.mark.parametrize("enable_sink", [False])
10921124
@pytest.mark.parametrize("max_in_kv_len", [110])
10931125
@pytest.mark.parametrize("head_dim", [256])
1126+
@pytest.mark.parametrize("device_scale", [True, False])
10941127
def test_trtllm_batch_decode_head_dim_256(
10951128
kv_layout,
10961129
batch_size,
@@ -1106,6 +1139,7 @@ def test_trtllm_batch_decode_head_dim_256(
11061139
enable_sink,
11071140
max_in_kv_len,
11081141
head_dim,
1142+
device_scale,
11091143
):
11101144
# Small number of test cases for head_dim = 256
11111145
pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256")
@@ -1125,7 +1159,7 @@ def test_trtllm_batch_decode_head_dim_256(
11251159
enable_sink,
11261160
max_in_kv_len,
11271161
head_dim,
1128-
True,
1162+
device_scale,
11291163
)
11301164

11311165

@@ -1151,6 +1185,7 @@ def test_trtllm_batch_decode_head_dim_256(
11511185
@pytest.mark.parametrize("enable_sink", [False])
11521186
@pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072])
11531187
@pytest.mark.parametrize("head_dim", [128])
1188+
@pytest.mark.parametrize("device_scale", [True, False])
11541189
def test_trtllm_batch_decode_long_sequence_length(
11551190
kv_layout,
11561191
batch_size,
@@ -1166,6 +1201,7 @@ def test_trtllm_batch_decode_long_sequence_length(
11661201
enable_sink,
11671202
max_in_kv_len,
11681203
head_dim,
1204+
device_scale,
11691205
):
11701206
# Small number of test cases for long sequence length
11711207
_test_trtllm_batch_decode(
@@ -1184,7 +1220,7 @@ def test_trtllm_batch_decode_long_sequence_length(
11841220
enable_sink,
11851221
max_in_kv_len,
11861222
head_dim,
1187-
False,
1223+
device_scale,
11881224
)
11891225

11901226

0 commit comments

Comments
 (0)