[Enhancement] Improve iterator handling in layout utilities and parallel operations #1221
Conversation
…lel operations

* Added a new function, `DivideUnusedIterators`, to detect per-iterator gaps in fused index expressions, enhancing the accuracy of unused iterator detection.
* Updated `CompleteBufferFragment` to prefer direct inversion for bijective index mappings and introduced a fallback mechanism for non-bijective cases, improving layout inversion robustness.
* Added a new test for layout inference in fused kernels to ensure correct compilation and execution without layout inversion failures.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run `…`. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

This PR refactors iterator split computation and optimizes buffer fragment completion for bijective mappings. It introduces a new `DivideUnusedIterators` function that detects per-iterator gaps in fused index expressions.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Caller
participant CBF as CompleteBufferFragment
participant Check as BijectiveCheck
participant FwdPath as FastBijectivePath
participant RepBSetup as RepB Setup
participant CardCheck as CardinalityCheck
participant FallBack as Fallback Path
participant Output as Fragment
Caller->>CBF: CompleteBufferFragment(buffer)
rect rgb(230, 240, 255)
Note over Check,FwdPath: Fast Bijective Path (new)
CBF->>Check: Check 2D bijective mapping
alt 2D indices form bijection
Check->>FwdPath: Invert 2D mapping directly
FwdPath->>Output: Return condensed Fragment
end
end
rect rgb(240, 230, 255)
Note over RepBSetup,CardCheck: Extended Path with RepB (new)
CBF->>RepBSetup: Create rep_b from unused iterators
RepBSetup->>RepBSetup: Flatten and extend indice_map
RepBSetup->>CardCheck: Check cardinality (in_prod vs out_prod)
alt Bijectivity holds after RepB
CardCheck->>CBF: Compute ind_inv, use ForwardThread
CBF->>Output: Return replication-aware Fragment
else Non-bijective after RepB
CardCheck->>FallBack: Compute non-replicated inverse
FallBack->>Output: Return CondenseReplicateVar Fragment
end
    end
```
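To make the two branches concrete, here is a minimal, self-contained Python sketch of the decision the diagram describes (illustrative only: `invert_if_bijective`, the brute-force enumeration, and the example extents are stand-ins, not the C++ implementation). Direct inversion is attempted only when the forward index map is a bijection from the input iteration space onto the output space; anything non-bijective goes to the replication-aware fallback.

```python
from itertools import product

def invert_if_bijective(fwd, in_extents, out_extent):
    """Try to invert an integer index map by enumeration.

    Returns a lookup table (output index -> input point) when `fwd` is a
    bijection from the input iteration space onto range(out_extent),
    else None, the analogue of falling back to the replication path.
    """
    image = {}
    for pt in product(*(range(e) for e in in_extents)):
        out = fwd(*pt)
        if out in image:              # two inputs collide: not injective
            return None
        image[out] = pt
    if len(image) != out_extent:      # outputs missed: not surjective
        return None
    return image

# Fused map (i, j) -> i*8 + j over a 16x8 space: bijective onto 128 slots,
# so the fast path can invert it directly as (idx // 8, idx % 8).
inv = invert_if_bijective(lambda i, j: i * 8 + j, (16, 8), 128)
assert inv is not None and inv[19] == (2, 3)

# Replicated map (i, j) -> i: every j collides, so the fallback fires.
assert invert_if_bijective(lambda i, j: i, (16, 8), 16) is None
```

The real code decides this symbolically, via the cardinality comparison shown in the diagram (in_prod vs out_prod), rather than by enumeration.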
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/layout/utils.cc (1)
133-151: Don't throw on fused IterMark sources
`NormalizeToIterSum` legitimately produces `IterMark` nodes whose source is an `IterSumExpr` (e.g., fused iterators). The new blanket check now throws a `NormalizeIterException` for those marks, so any fused layout that previously worked will immediately fail. Instead of rejecting them, skip non-Var sources when merging splits and keep the existing behavior for Var-backed marks. Apply this diff to keep fused marks legal:
```diff
-  for (const IterMark &mark : collector.visited_) {
-    if (!mark->source.as<Var>()) {
-      std::ostringstream oss;
-      oss << "Not a normalized iterator: " << mark;
-      throw NormalizeIterException(oss.str());
-    }
-  }
-
   for (const IterVar &iter : input_iters) {
     // Merge splits from all IterMark that share the same source Var as `iter`.
     std::vector<IterSplitExpr> merged_splits;
     for (const IterMark &mark : collector.visited_) {
-      auto vexpr = mark->source.as<Var>();
-      if (vexpr && vexpr.value().same_as(iter->var)) {
+      auto vexpr = mark->source.as<Var>();
+      if (!vexpr)
+        continue;
+      if (vexpr.value().same_as(iter->var)) {
         auto it = collector.mark2splits_.find(mark);
         if (it != collector.mark2splits_.end()) {
           const auto &vec = it->second;
           merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
         }
```
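As background on why these marks matter: per the PR description, `DivideUnusedIterators` finds iterators that never reach the fused index expression, and such iterators must become replication dimensions. A brute-force Python stand-in for that idea (the function name, enumeration strategy, and example extents below are illustrative assumptions, not the actual algorithm):

```python
from itertools import product

def unused_iterators(index_fn, extents):
    """Brute-force stand-in for the per-iterator gap check: iterator k is
    "unused" if the fused index never depends on coordinate k."""
    unused = []
    for k, ext in enumerate(extents):
        def depends_on_k(pt):
            base = list(pt)
            ref = None
            for v in range(ext):
                base[k] = v
                cur = index_fn(*base)
                if ref is None:
                    ref = cur
                elif cur != ref:      # varying k changed the index
                    return True
            return False
        others = [range(e) if axis != k else range(1)
                  for axis, e in enumerate(extents)]
        if not any(depends_on_k(pt) for pt in product(*others)):
            unused.append(k)
    return unused

# j feeds the fused index but i does not: i must become a replication dim.
assert unused_iterators(lambda i, j: j, (4, 8)) == [0]
# A fully fused index uses both iterators: nothing is unused.
assert unused_iterators(lambda i, j: i * 8 + j, (4, 8)) == []
```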
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
* src/layout/utils.cc (2 hunks)
* src/op/parallel.cc (1 hunks)
* testing/python/layout/test_tilelang_layout_fused_replicate.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/parallel.cc (2)
src/layout/utils.cc (6)
* ToVMap (268-274), ToVMap (268-268)
* MakeFlattenedExpression (170-180), MakeFlattenedExpression (170-170)
* DivideUnusedIterators (122-168), DivideUnusedIterators (122-124)

src/layout/layout.cc (6)

* Layout (57-69), Layout (71-74)
* InputPlaceholder (30-32), InputPlaceholder (30-30)
* Fragment (318-340), Fragment (342-352)
testing/python/layout/test_tilelang_layout_fused_replicate.py (4)
tilelang/testing/__init__.py (1)
* set_random_seed (30-35)

tilelang/language/allocate.py (1)

* alloc_fragment (59-70)

tilelang/language/loop.py (1)

* Parallel (12-32)

tilelang/language/v2/dtypes.py (2)

* bfloat16 (297-297), float32 (200-200)
🪛 Ruff (0.14.3)
testing/python/layout/test_tilelang_layout_fused_replicate.py
18-18: Unused function argument: a
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
```python
for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
    idx = i * BLOCK_K + j
    a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
```
Initialize the fragment before reading from it
`a_out` is filled from `a_fp32_local`, but that fragment is never written; `a` is ignored entirely. This leaves the store sourcing undefined data from uninitialized memory. Please load (or otherwise initialize) the fragment before using it and consume the input tensor.

Apply this diff to populate the fragment from `a`:
```diff
 for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
     idx = i * BLOCK_K + j
-    a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
+    a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE] = T.cast(
+        a[pid_b, offs_m + i, offs_n + j], "float32"
+    )
+    a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
-    idx = i * BLOCK_K + j
-    a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
+for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
+    idx = i * BLOCK_K + j
+    a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE] = T.cast(
+        a[pid_b, offs_m + i, offs_n + j], "float32"
+    )
+    a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
```
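Incidentally, the indexing in this suggestion is the same fused pattern the PR targets: `idx = i * BLOCK_K + j` flattens the loop pair and is then re-split by `VEC_SIZE`. A quick standalone sanity check in plain Python (the concrete extents `BLOCK_MN = 4`, `BLOCK_K = 8`, `VEC_SIZE = 8` are assumed for illustration):

```python
BLOCK_MN, BLOCK_K, VEC_SIZE = 4, 8, 8  # assumed example extents

seen = set()
for i in range(BLOCK_MN):
    for j in range(BLOCK_K):
        idx = i * BLOCK_K + j
        seen.add((idx // VEC_SIZE, idx % VEC_SIZE))

# Every (row, lane) slot of the fragment is hit exactly once, so the
# fused map is bijective and the fast inversion path can apply.
assert len(seen) == BLOCK_MN * BLOCK_K
```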
🤖 Prompt for AI Agents
testing/python/layout/test_tilelang_layout_fused_replicate.py around lines 31 to
34: the code stores from a_fp32_local into a_out but never initializes
a_fp32_local from the input tensor a, so stores read uninitialized memory; fix
by loading/initializing the fragment before the Parallel store — perform the
corresponding read from the input tensor a into a_fp32_local (e.g., a.load or an
explicit loop that reads a into the fragment using the same indexing/vec layout)
so the fragment is consumed and then use that initialized fragment when writing
to a_out.
…llel operations (tile-ai#1221)

* [Enhancement] Improve iterator handling in layout utilities and parallel operations
* Added a new function, DivideUnusedIterators, to detect per-iterator gaps in fused index expressions, enhancing the accuracy of unused iterator detection.
* Updated CompleteBufferFragment to prefer direct inversion for bijective index mappings and introduced a fallback mechanism for non-bijective cases, improving layout inversion robustness.
* Added a new test for layout inference in fused kernels to ensure correct compilation and execution without layout inversion failures.
* lint fix
Summary by CodeRabbit
New Features
Tests