[BugFix] Phaseout unused tests for gqa decode kernels and add the kernels to CI#1515
[BugFix] Phaseout unused tests for gqa decode kernels and add the kernels to CI#1515LeiWang1999 merged 1 commit intotile-ai:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughThe PR reorganizes flash decoding example entry points by adding Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
200-200: Consider documenting or removing the commented autotune decorator.The commented-out
@autotunedecorator suggests temporary disabling. If autotuning is not needed for CI tests, consider removing the line. Otherwise, add a comment explaining why it's disabled.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/flash_decoding/example_gqa_decode_varlen_logits.pyexamples/flash_decoding/example_gqa_decode_varlen_logits_paged.pyexamples/flash_decoding/test_example_flash_decoding.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/flash_decoding/example_gqa_decode_varlen_logits.py
🧬 Code graph analysis (3)
examples/flash_decoding/test_example_flash_decoding.py (2)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
main(770-784)examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (1)
main(522-537)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (1)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
main(770-784)test_varlen_decode_main(440-641)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)
main(522-537)test_varlen_decode_main(202-407)examples/flash_decoding/example_gqa_decode.py (1)
main(440-483)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (8)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)
522-538: LGTM! Entry point for CI testing.The
main()function provides a clean entry point for the CI test suite with appropriate default parameters. The addition ofpage_block_size=128is specific to this paged variant and aligns with the kernel's requirements.
563-563: LGTM! Simplified test execution path.The change to directly call
test_varlen_decode_main(args)aligns with the PR objective to phase out unused equal-length decode tests and focus on variable-length testing.examples/flash_decoding/test_example_flash_decoding.py (2)
5-6: LGTM! Module imports for CI tests.The imports enable the new test functions to call the
main()entry points added to these modules.
20-26: No action needed. The varlen test modules (example_gqa_decode_varlen_logitsandexample_gqa_decode_varlen_logits_paged) show no evidence of SM90 correctness issues, unlike the base gqa_decode module which explicitly documents this issue in line 9 of the test file (# TODO(lei): fix the correctness of gqa decode on sm90). The varlen modules contain unrelated TODOs about max_seqlen_kv validation and split versions, with no SM90 or compute version constraints mentioned anywhere in their code. The decorators are not required for these tests.Likely an incorrect or invalid review comment.
examples/flash_decoding/example_gqa_decode_varlen_logits.py (4)
623-627: LGTM! Improved shape alignment for comparison.The change ensures both tensors are sliced to the same size before computing the difference, making the comparison more robust and preventing potential shape mismatch errors.
634-639: LGTM! Consistent shape alignment in assertion.The multi-line assertion format improves readability, and the slicing on both tensors ensures shape consistency with the
max_diff_s_tlcalculation above.
770-785: LGTM! Entry point for CI testing.The
main()function provides a clean entry point for the CI test suite with appropriate default parameters. The configuration aligns with the updated defaults in the argument parser.
802-809: LGTM! Varlen testing now default.The changes make variable-length testing the default path, aligning with the PR objective to phase out equal-length decode tests. The execution flow now consistently uses
test_varlen_decode_main(args).
| parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") | ||
| parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") | ||
| parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") | ||
| parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if block_size=64 is used elsewhere in the codebase
rg -nP --type=py 'block_size\s*=\s*64' examples/Repository: tile-ai/tilelang
Length of output: 1496
🏁 Script executed:
#!/bin/bash
# Check the broader context of block_size usage
rg -nP --type=py 'block_size' examples/flash_decoding/ | head -20Repository: tile-ai/tilelang
Length of output: 2950
🏁 Script executed:
#!/bin/bash
# Check the git history of this file to confirm the change from 64 to 128
git log -p --follow examples/flash_decoding/example_gqa_decode_varlen_logits.py | grep -A 5 -B 5 "block_size.*default" | head -30Repository: tile-ai/tilelang
Length of output: 42
Parameter default changed from 64 to 128 inconsistent with codebase patterns.
The block_size default has been changed from 64 to 128, creating an inconsistency with the rest of the codebase where block_size=64 is the standard default across multiple modules (deepseek_nsa, deepseek_mla, and even the related example_gqa_decode_varlen_logits_paged.py file). Document this deviation and ensure the change is intentional and doesn't break existing workflows.
🤖 Prompt for AI Agents
In examples/flash_decoding/example_gqa_decode_varlen_logits.py around line 794,
the parser default for --block_size was changed from 64 to 128 which is
inconsistent with the rest of the codebase; either revert the default back to 64
to match deepseek_nsa/deepseek_mla and
example_gqa_decode_varlen_logits_paged.py, or if 128 is intentional, add a clear
inline comment and update documentation/examples to explain the deviation and
run a quick smoke test to ensure no workflows break.
As title
Summary by CodeRabbit
New Features
Tests
✏️ Tip: You can customize this high-level summary in your review settings.