diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py
index 508fce831d..d56be03eb6 100644
--- a/tests/attention/test_trtllm_gen_mla.py
+++ b/tests/attention/test_trtllm_gen_mla.py
@@ -9,20 +9,7 @@
 workspace_size = 128 * 1024 * 1024
 
 
-@pytest.mark.parametrize(
-    "batch_size",
-    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
-)
-@pytest.mark.parametrize("scale", [1.0, 0.5])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
-@pytest.mark.parametrize("page_size", [32, 64])
-@pytest.mark.parametrize(
-    "q_len_per_request", [1, 2]
-)  # todo(Yingyi): verify larger q_len_per_request
-@pytest.mark.parametrize("dynamic_scale", [False])
-@pytest.mark.parametrize("enable_pdl", [True, False, None])
-@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
-def test_trtllm_batch_decode_mla(
+def trtllm_batch_decode_mla(
     batch_size: int,
     scale: float,
     dtype: torch.dtype,
@@ -31,6 +18,7 @@ def test_trtllm_batch_decode_mla(
     dynamic_scale: bool,
     enable_pdl: bool,
     backend: str,
+    MAX_SEQ_LEN: int,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if backend == "xqa":
@@ -49,9 +37,6 @@ def test_trtllm_batch_decode_mla(
     torch.manual_seed(42)
     device = "cuda:0"
 
-    # Fixed max sequence length
-    MAX_SEQ_LEN = 1024
-
     # Deepseek attention config (decode-MLA)
     num_q_heads = 128
     qk_nope_head_dim = 128
@@ -239,3 +224,75 @@
             f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, "
             f"require at least {required_ratio:.1%}"
         )
+
+
+@pytest.mark.parametrize(
+    "batch_size",
+    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
+)
+@pytest.mark.parametrize("scale", [1.0, 0.5])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("page_size", [32, 64])
+@pytest.mark.parametrize(
+    "q_len_per_request", [1, 2]
+)  # todo(Yingyi): verify larger q_len_per_request
+@pytest.mark.parametrize("dynamic_scale", [False])
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
+def test_trtllm_batch_decode_mla(
+    batch_size: int,
+    scale: float,
+    dtype: torch.dtype,
+    page_size: int,
+    q_len_per_request: int,
+    dynamic_scale: bool,
+    enable_pdl: bool,
+    backend: str,
+):
+    trtllm_batch_decode_mla(
+        batch_size,
+        scale,
+        dtype,
+        page_size,
+        q_len_per_request,
+        dynamic_scale,
+        enable_pdl,
+        backend,
+        1024,
+    )
+
+
+@pytest.mark.parametrize(
+    "batch_size",
+    [2, 4, 8],
+)
+@pytest.mark.parametrize("scale", [1.0, 0.5])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("page_size", [64])
+@pytest.mark.parametrize("q_len_per_request", [1, 2, 3])
+@pytest.mark.parametrize("dynamic_scale", [False])
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("backend", ["trtllm-gen"])
+@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
+def test_dsr1_trtllm_mla(
+    batch_size: int,
+    scale: float,
+    dtype: torch.dtype,
+    page_size: int,
+    q_len_per_request: int,
+    dynamic_scale: bool,
+    enable_pdl: bool,
+    backend: str,
+    MAX_SEQ_LEN: int,
+):
+    trtllm_batch_decode_mla(
+        batch_size,
+        scale,
+        dtype,
+        page_size,
+        q_len_per_request,
+        dynamic_scale,
+        enable_pdl,
+        backend,
+        MAX_SEQ_LEN,
+    )