
Commit 9d9e4fc

Merge remote-tracking branch 'origin/main' into enable_xqa_spec_dec
2 parents: 95ce39b + b9964cc

File tree: 5 files changed, +468 −195 lines


tests/attention/test_trtllm_gen_attention.py

Lines changed: 2 additions & 2 deletions
@@ -1087,6 +1087,7 @@ def test_trtllm_batch_decode(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
     [
         (1, 1, 16, 8, 8),
+        (1, 1, 32, 8, 8),
     ],
 )
 @pytest.mark.parametrize("window_left", [-1])
@@ -1098,7 +1099,7 @@ def test_trtllm_batch_decode(
 )
 @pytest.mark.parametrize("enable_pdl", [None])
 @pytest.mark.parametrize("enable_sink", [False])
-@pytest.mark.parametrize("max_in_kv_len", [8192])
+@pytest.mark.parametrize("max_in_kv_len", [4096, 8192])
 @pytest.mark.parametrize("head_dim", [128])
 @pytest.mark.parametrize("device_scale", [True, False])
 def test_trtllm_batch_decode_bs1(
@@ -1119,7 +1120,6 @@ def test_trtllm_batch_decode_bs1(
     device_scale,
 ):
     # Small number of test cases for batch size 1
-    pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
     _test_trtllm_batch_decode(
         "trtllm-gen",
         kv_layout,
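
With the pytest.xfail removed, the batch-size-1 decode case runs again, and the new parametrize values (the page_size=32 tuple and max_in_kv_len=4096) enlarge its grid. For context, stacked @pytest.mark.parametrize decorators expand to the cartesian product of their value lists, so each added value multiplies the number of generated cases. A minimal, self-contained sketch of that mechanism (toy test names and values, not part of this diff):

import pytest

# Stacked parametrize decorators combine multiplicatively:
# 2 values for "page_size" x 2 values for "max_in_kv_len" -> 4 generated cases.
@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("max_in_kv_len", [4096, 8192])
def test_param_grid(page_size, max_in_kv_len):
    assert page_size > 0 and max_in_kv_len > 0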

tests/attention/test_trtllm_gen_mla.py

Lines changed: 74 additions & 17 deletions
@@ -9,20 +9,7 @@
 workspace_size = 128 * 1024 * 1024
 
 
-@pytest.mark.parametrize(
-    "batch_size",
-    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
-)
-@pytest.mark.parametrize("scale", [1.0, 0.5])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
-@pytest.mark.parametrize("page_size", [32, 64])
-@pytest.mark.parametrize(
-    "q_len_per_request", [1, 2]
-) # todo(Yingyi): verify larger q_len_per_request
-@pytest.mark.parametrize("dynamic_scale", [False])
-@pytest.mark.parametrize("enable_pdl", [True, False, None])
-@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
-def test_trtllm_batch_decode_mla(
+def trtllm_batch_decode_mla(
     batch_size: int,
     scale: float,
     dtype: torch.dtype,
@@ -31,6 +18,7 @@ def test_trtllm_batch_decode_mla(
     dynamic_scale: bool,
     enable_pdl: bool,
     backend: str,
+    MAX_SEQ_LEN: int,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if backend == "xqa":
@@ -49,9 +37,6 @@ def test_trtllm_batch_decode_mla(
     torch.manual_seed(42)
     device = "cuda:0"
 
-    # Fixed max sequence length
-    MAX_SEQ_LEN = 1024
-
     # Deepseek attention config (decode-MLA)
     num_q_heads = 128
     qk_nope_head_dim = 128
@@ -239,3 +224,75 @@ def test_trtllm_batch_decode_mla(
         f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, "
         f"require at least {required_ratio:.1%}"
     )
+
+
+@pytest.mark.parametrize(
+    "batch_size",
+    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
+)
+@pytest.mark.parametrize("scale", [1.0, 0.5])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("page_size", [32, 64])
+@pytest.mark.parametrize(
+    "q_len_per_request", [1, 2]
+) # todo(Yingyi): verify larger q_len_per_request
+@pytest.mark.parametrize("dynamic_scale", [False])
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
+def test_trtllm_batch_decode_mla(
+    batch_size: int,
+    scale: float,
+    dtype: torch.dtype,
+    page_size: int,
+    q_len_per_request: int,
+    dynamic_scale: bool,
+    enable_pdl: bool,
+    backend: str,
+):
+    trtllm_batch_decode_mla(
+        batch_size,
+        scale,
+        dtype,
+        page_size,
+        q_len_per_request,
+        dynamic_scale,
+        enable_pdl,
+        backend,
+        1024,
+    )
+
+
+@pytest.mark.parametrize(
+    "batch_size",
+    [2, 4, 8],
+)
+@pytest.mark.parametrize("scale", [1.0, 0.5])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("page_size", [64])
+@pytest.mark.parametrize("q_len_per_request", [1, 2, 3])
+@pytest.mark.parametrize("dynamic_scale", [False])
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("backend", ["trtllm-gen"])
+@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
+def test_dsr1_trtllm_mla(
+    batch_size: int,
+    scale: float,
+    dtype: torch.dtype,
+    page_size: int,
+    q_len_per_request: int,
+    dynamic_scale: bool,
+    enable_pdl: bool,
+    backend: str,
+    MAX_SEQ_LEN: int,
+):
+    trtllm_batch_decode_mla(
+        batch_size,
+        scale,
+        dtype,
+        page_size,
+        q_len_per_request,
+        dynamic_scale,
+        enable_pdl,
+        backend,
+        MAX_SEQ_LEN,
+    )
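
This refactor turns the original test body into a shared helper, trtllm_batch_decode_mla, with MAX_SEQ_LEN as an explicit argument: test_trtllm_batch_decode_mla keeps the previous fixed 1024 setting, while the new test_dsr1_trtllm_mla also covers an 8960-token DeepSeek-R1-style configuration. A sketch of how one might run just that new grid locally via pytest's Python entry point (the node id comes from the file and test names in this diff; the -k filter on "8960" is an assumption that the parametrize IDs include that value, which is pytest's default for integers):

import pytest

# Run only the new DeepSeek-R1 MLA test, filtered to the MAX_SEQ_LEN=8960 cases.
# pytest.main() accepts the same arguments as the CLI and returns an exit code.
if __name__ == "__main__":
    raise SystemExit(
        pytest.main(
            [
                "tests/attention/test_trtllm_gen_mla.py::test_dsr1_trtllm_mla",
                "-k", "8960",
                "-q",
            ]
        )
    )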
