[Enhancement] Add thread count validation for ReduceOp fragment layout inference #1225
LeiWang1999 merged 2 commits into tile-ai:main
…t inference
* Introduced a check to ensure that the thread count is divisible by the replicate extent during layout inference in ReduceOpNode. This validation prevents layout inference failures and provides detailed error messages to guide users in resolving issues related to thread block sizes and fragment layouts.
* Updated tests to remove unsupported configurations that could lead to layout inference errors, ensuring more robust testing scenarios.
Walkthrough
The changes add a divisibility guard to reduce operation layout inference. The guard uses an arithmetic analyzer to verify that the thread count is divisible by the replicate extent, and aborts if this cannot be proven. Two related test cases are removed from the reduce test suite.
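The "abort if unproven" behavior described in the walkthrough can be sketched in plain C++. This is a simplified illustration, not the code from src/op/reduce.cc: `Extent`, `CanProveDivisible`, and `RequireDivisible` are hypothetical names, and `std::optional` stands in for a symbolic value that may not be known at compile time. The real analyzer reasons over symbolic expressions; here the stand-in only handles known constants.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for a symbolic extent: std::nullopt means the value
// is not known at compile time.
using Extent = std::optional<std::int64_t>;

// Conservative "prover": returns true only when divisibility is certain.
// A false result means "could not prove", not "proved false".
bool CanProveDivisible(Extent value, Extent divisor) {
  if (!value || !divisor || *divisor <= 0) return false;
  return *value % *divisor == 0;
}

// Guard in the spirit of the PR: stop layout inference when divisibility
// cannot be established, instead of continuing with a possibly invalid layout.
void RequireDivisible(Extent num_threads, Extent replicate_extent) {
  if (!CanProveDivisible(num_threads, replicate_extent)) {
    throw std::runtime_error(
        "layout inference failed: could not prove "
        "num_threads % replicate_extent == 0");
  }
}
```

Under these assumptions, a call such as `RequireDivisible(128, 4)` returns normally, while an unknown or non-divisible thread count raises an exception, mirroring the role `ICHECK(false)` plays in the PR.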
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/op/reduce.cc (1 hunks)
- testing/python/language/test_tilelang_language_reduce.py (0 hunks)
💤 Files with no reviewable changes (1)
- testing/python/language/test_tilelang_language_reduce.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/reduce.cc (2)
- tilelang/language/kernel.py (1): num_threads (222-226)
- src/tl_templates/cuda/reduce.h (1): T (208-280)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (1)
src/op/reduce.cc (1)
393-419: Divisibility validation logic is sound.
The bidirectional check correctly ensures that either num_threads is divisible by dest_buffer_rep_extent or vice versa, preventing invalid fragment-to-fragment mappings. The use of arith::Analyzer::CanProve and ICHECK(false) for compile-time validation is appropriate.
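The bidirectional condition the review comment describes can be illustrated on plain integers. This is a minimal sketch under the assumption that both extents are known positive constants; `ExtentsCompatible` is a hypothetical helper and not a function from the repository, where the check is phrased through the arithmetic analyzer instead.

```cpp
#include <cstdint>

// Hypothetical helper: true when one extent evenly divides the other,
// matching the "either direction" wording of the review comment.
bool ExtentsCompatible(std::int64_t num_threads, std::int64_t rep_extent) {
  // Non-positive extents are rejected so the modulo below is well defined.
  if (num_threads <= 0 || rep_extent <= 0) return false;
  return num_threads % rep_extent == 0 || rep_extent % num_threads == 0;
}
```

For example, 128 threads with a replicate extent of 4 pass, as do 4 threads with an extent of 128, while 96 threads with an extent of 64 fail because neither value divides the other.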
    ICHECK(false) << "ReduceOp fragment layout inference failed: "
                     "num_threads % replicate_extent != 0. "
                  << "This mapping requires the block's thread count to be "
                     "divisible by the "
                  << "replicate extent. "
                  << "Try one of: (1) choose a thread block size divisible "
                     "by replicate_extent; "
                  << "(2) pick a different reduce dimension or adjust the "
                     "source fragment layout; "
                  << "Details: num_threads=" << num_threads
                  << ", replicate_extent=" << indice_rep_extent
                  << ", src=" << src << ", dst=" << dst;
    }
Error message should print the actual value being checked.
The error message on line 416 prints indice_rep_extent, but the divisibility check validates dest_buffer_rep_extent (which equals indice_rep_extent * src_rep_extent from line 378). This discrepancy could confuse users debugging the error.
Apply this diff to improve error message accuracy:
<< "Details: num_threads=" << num_threads
- << ", replicate_extent=" << indice_rep_extent
+ << ", replicate_extent=" << dest_buffer_rep_extent
<< ", src=" << src << ", dst=" << dst;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    ICHECK(false) << "ReduceOp fragment layout inference failed: "
                     "num_threads % replicate_extent != 0. "
                  << "This mapping requires the block's thread count to be "
                     "divisible by the "
                  << "replicate extent. "
                  << "Try one of: (1) choose a thread block size divisible "
                     "by replicate_extent; "
                  << "(2) pick a different reduce dimension or adjust the "
                     "source fragment layout; "
                  << "Details: num_threads=" << num_threads
                  << ", replicate_extent=" << dest_buffer_rep_extent
                  << ", src=" << src << ", dst=" << dst;
    }
🤖 Prompt for AI Agents
In src/op/reduce.cc around lines 406 to 418, the error message prints
indice_rep_extent but the divisibility check is actually on
dest_buffer_rep_extent (which equals indice_rep_extent * src_rep_extent); update
the ICHECK message to print dest_buffer_rep_extent (or include both
dest_buffer_rep_extent and indice_rep_extent) so the logged values match the
condition being checked and give accurate debugging info.
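The second option mentioned in the prompt, reporting both values, might look like the sketch below. `FormatDivisibilityError` is a hypothetical helper written with the standard library only; the repository streams its message through ICHECK instead. The only fact taken from the review is the relation dest_buffer_rep_extent = indice_rep_extent * src_rep_extent.

```cpp
#include <cstdint>
#include <sstream>
#include <string>

// Hypothetical helper: builds the "Details" part of the error message so that
// the printed replicate extent is the value the divisibility check uses.
std::string FormatDivisibilityError(std::int64_t num_threads,
                                    std::int64_t indice_rep_extent,
                                    std::int64_t src_rep_extent) {
  // Relation stated in the review comment.
  const std::int64_t dest_buffer_rep_extent =
      indice_rep_extent * src_rep_extent;
  std::ostringstream os;
  os << "num_threads=" << num_threads
     << ", replicate_extent=" << dest_buffer_rep_extent
     << " (indice_rep_extent=" << indice_rep_extent
     << ", src_rep_extent=" << src_rep_extent << ")";
  return os.str();
}
```

Printing the product alongside its factors lets a reader see both the value that failed the check and where it came from.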
…t inference (tile-ai#1225)
* [Enhancement] Add thread count validation for ReduceOp fragment layout inference
* Introduced a check to ensure that the thread count is divisible by the replicate extent during layout inference in ReduceOpNode. This validation prevents layout inference failures and provides detailed error messages to guide users in resolving issues related to thread block sizes and fragment layouts.
* Updated tests to remove unsupported configurations that could lead to layout inference errors, ensuring more robust testing scenarios.
* lint fix
This pull request improves the robustness of layout inference for reduce operations and updates the corresponding test coverage. The most significant change is the addition of a check to ensure that the thread count is divisible by the replicate extent during layout inference, which prevents invalid fragment-to-fragment mappings. Additionally, some test cases that would violate this new constraint have been removed.
Reduce operation layout inference improvements:
- Added a check in ReduceOpNode::InferLayout (src/op/reduce.cc) to ensure that the thread count (num_threads) is divisible by the replicate extent (replicate_extent). If this condition is not met, an error is raised with a detailed message, preventing invalid fragment layout inference and guiding users on how to resolve the issue.
Testing updates:
- Removed test cases from test_reduce_sum_shared and test_reduce_max_shared (testing/python/language/test_tilelang_language_reduce.py) that used thread and replicate extent combinations which are now disallowed by the new divisibility check.