Skip to content

Commit b9964cc

Browse files
bkryuyzh119
andauthored
test: Enable testing for trtllm-gen decode bs1 (#2103)
<!-- .github/pull_request_template.md --> ## 📌 Description In #1898, it was raised that trtllm-gen's attention kernels fail for batch size 1. The prefill kernel was fixed in #1912 and prefill tests have been enabled. Further updates to trtllm-gen kernels have also fixed the decode batch size 1 issue. Current PR re-enables testing. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Expanded batch_decode test scenarios to cover additional small-batch and page-size combinations. * Increased coverage for max_in_kv_len by testing multiple length options instead of a single value. * Restored previously marked-as-expected-failure case to run normally, improving overall test pass coverage. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Zihao Ye <[email protected]>
1 parent 219592b commit b9964cc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/attention/test_trtllm_gen_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,7 @@ def test_trtllm_batch_decode(
10411041
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
10421042
[
10431043
(1, 1, 16, 8, 8),
1044+
(1, 1, 32, 8, 8),
10441045
],
10451046
)
10461047
@pytest.mark.parametrize("window_left", [-1])
@@ -1052,7 +1053,7 @@ def test_trtllm_batch_decode(
10521053
)
10531054
@pytest.mark.parametrize("enable_pdl", [None])
10541055
@pytest.mark.parametrize("enable_sink", [False])
1055-
@pytest.mark.parametrize("max_in_kv_len", [8192])
1056+
@pytest.mark.parametrize("max_in_kv_len", [4096, 8192])
10561057
@pytest.mark.parametrize("head_dim", [128])
10571058
@pytest.mark.parametrize("device_scale", [True, False])
10581059
def test_trtllm_batch_decode_bs1(
@@ -1073,7 +1074,6 @@ def test_trtllm_batch_decode_bs1(
10731074
device_scale,
10741075
):
10751076
# Small number of test cases for batch size 1
1076-
pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
10771077
_test_trtllm_batch_decode(
10781078
"trtllm-gen",
10791079
kv_layout,

0 commit comments

Comments
 (0)