unittest: Add head dim 256 test cases and mark as xfail#1999
unittest: Add head dim 256 test cases and mark as xfail#1999yzh119 merged 3 commits intoflashinfer-ai:mainfrom
Conversation
WalkthroughTests updated to parameterize batch-decode tests with a new Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Comment |
| head_dim, | ||
| ): | ||
| pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") | ||
| test_trtllm_batch_decode( |
There was a problem hiding this comment.
It's unusual practice (but it's totally okay to do so and it's not introduced in this PR) to call one test_* function instead another one, all top-level functions with prefix test_ will be treated as standalone unittests.
Can we create a function _test_trtllm_batch_decode as the common body of these unittests, instead of calling another top-level test_trtllm_batch_decode function?
There was a problem hiding this comment.
Hi @yzh119, I think this PR is a good opportunity to do make this change.
I have:
- renamed
test_trtllm_batch_decodeto_test_trtllm_batch_decodeas a base function. - There are test functions that call
_test_trtllm_batch_decodewith a group of parameter combinations:test_trtllm_batch_decode--> 1632 existing parameter combinationstest_trtllm_batch_decode_bs1--> 1 xfail case with batch size 1test_trtllm_batch_decode_head_dim_256--> 40 xfail cases with head_dim=256.test_trtllm_batch_decode_long_sequence_length--> 48 cases of long seqlen.
There was a problem hiding this comment.
The long seqlen was added because I saw #1968 and tested what happens if try testing long seqlens. We start to see failures starting from 4k
1993033 to
097308f
Compare
|
/bot run |
|
[SUCCESS] Pipeline #37485772: 13/17 passed |
📌 Description
Adding unit test for
head_dim=256cases for trtllm-gen decode and marking them as xfail.Renames
test_trtllm_batch_decodeto_test_trtllm_batch_decodeas a base function. Test functions now call _test_trtllm_batch_decode with a group of parameter combinations:test_trtllm_batch_decode--> 1632 existing parameter combinationstest_trtllm_batch_decode_bs1--> 1 xfail case with batch size 1test_trtllm_batch_decode_head_dim_256--> 40 xfail cases with head_dim=256.test_trtllm_batch_decode_long_sequence_length--> 48 cases of long seqlen.🔍 Related Issues
#1993
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit