Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 74 additions & 17 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,7 @@
workspace_size = 128 * 1024 * 1024


@pytest.mark.parametrize(
"batch_size",
[1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [32, 64])
@pytest.mark.parametrize(
"q_len_per_request", [1, 2]
) # todo(Yingyi): verify larger q_len_per_request
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
def test_trtllm_batch_decode_mla(
def trtllm_batch_decode_mla(
batch_size: int,
scale: float,
dtype: torch.dtype,
Expand All @@ -31,6 +18,7 @@ def test_trtllm_batch_decode_mla(
dynamic_scale: bool,
enable_pdl: bool,
backend: str,
MAX_SEQ_LEN: int,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Per PEP 8, function parameter names should be in lowercase_with_underscores. Please rename MAX_SEQ_LEN to max_seq_len for consistency. You will also need to update its usage within the function body.

Suggested change
MAX_SEQ_LEN: int,
max_seq_len: int,

):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if backend == "xqa":
Expand All @@ -49,9 +37,6 @@ def test_trtllm_batch_decode_mla(
torch.manual_seed(42)
device = "cuda:0"

# Fixed max sequence length
MAX_SEQ_LEN = 1024

# Deepseek attention config (decode-MLA)
num_q_heads = 128
qk_nope_head_dim = 128
Expand Down Expand Up @@ -239,3 +224,75 @@ def test_trtllm_batch_decode_mla(
f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, "
f"require at least {required_ratio:.1%}"
)


@pytest.mark.parametrize(
    "batch_size",
    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [32, 64])
@pytest.mark.parametrize(
    "q_len_per_request", [1, 2]
)  # todo(Yingyi): verify larger q_len_per_request
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
def test_trtllm_batch_decode_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
):
    """Sweep batch-decode MLA over both backends at a fixed 1024 max sequence length.

    Thin parametrized wrapper: all kernel execution and reference checking
    happens in the shared ``trtllm_batch_decode_mla`` helper.
    """
    trtllm_batch_decode_mla(
        batch_size=batch_size,
        scale=scale,
        dtype=dtype,
        page_size=page_size,
        q_len_per_request=q_len_per_request,
        dynamic_scale=dynamic_scale,
        enable_pdl=enable_pdl,
        backend=backend,
        MAX_SEQ_LEN=1024,  # fixed max sequence length for this sweep
    )


@pytest.mark.parametrize(
    "batch_size",
    [2, 4, 8],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [64])
@pytest.mark.parametrize("q_len_per_request", [1, 2, 3])
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen"])
# PEP 8: function parameters use lowercase_with_underscores, so the
# parametrize id and the parameter are named max_seq_len (not MAX_SEQ_LEN).
@pytest.mark.parametrize("max_seq_len", [1024, 8960])
def test_dsr1_trtllm_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    max_seq_len: int,
):
    """DeepSeek-R1 MLA decode sweep with larger maximum sequence lengths.

    Exercises only the trtllm-gen backend, extending q_len_per_request to 3
    and max_seq_len to 8960. Kernel execution and the reference comparison
    are delegated to the shared ``trtllm_batch_decode_mla`` helper.
    """
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        max_seq_len,
    )
Comment on lines +276 to +298
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Following PEP 8 style guidelines, parameter and variable names should be in lowercase_with_underscores. Please rename MAX_SEQ_LEN to max_seq_len in the parametrize decorator, the test function signature, and the call to the helper function.

@pytest.mark.parametrize("max_seq_len", [1024, 8960])
def test_dsr1_trtllm_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    max_seq_len: int,
):
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        max_seq_len,
    )