
[Layout] Fix Layout Bugs in Parallel and Reduce#1713

Merged
kurisu6912 merged 15 commits into tile-ai:main from kurisu6912:fix-loop-layout
Jan 28, 2026

Conversation

@kurisu6912
Collaborator

@kurisu6912 kurisu6912 commented Jan 22, 2026

  1. Add read/write checking in T.Parallel: when a loop reads a buffer, the buffer layout must contain the loop layout (so every read hits the buffer); and when a loop writes a buffer, the loop layout must contain the buffer layout (so every place in the buffer gets written).
  2. Refactor layout inference of T.reduce_xx: first infer a reducer layout, then perform the reduction using the reducer layout, and finally copy the reducer to the target buffer.
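The containment rule in item 1 can be pictured with a small plain-Python sketch (illustrative only — this models layouts as index sets, not the actual TileLang C++ layout objects; the function names are made up for this example):

```python
# Hypothetical sketch of the read/write containment checks described above.
# A "layout" here is simply the set of buffer indices the loop touches,
# or the set of indices the buffer's fragment layout owns.

def read_is_safe(loop_indices, buffer_indices):
    # Read: the buffer layout must contain the loop layout,
    # so every index the loop reads actually hits the buffer.
    return loop_indices <= buffer_indices

def write_is_safe(loop_indices, buffer_indices):
    # Write: the loop layout must contain the buffer layout,
    # so every element of the buffer is written by some iteration.
    return buffer_indices <= loop_indices
```

Note the asymmetry: a read only needs to land inside the buffer, while a write must cover the whole buffer, which is why the two checks point in opposite directions.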

Summary by CodeRabbit

  • Refactor

    • Access tracking centralized to record per-buffer indices and read/write intent, simplifying parallel access iteration and planning.
  • Behavior Change

    • Layout inference and conflict reporting now use reducer-aware checks and per-buffer access metadata; guarded/conditional writes added to avoid unsafe writes during reductions.
  • Tests

    • Added regression tests exercising multiple layout and reduction scenarios to prevent regressions.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 22, 2026

📝 Walkthrough


Centralizes per-buffer access tracking in ParallelOpNode via a new BufferAccessInfo and RecordBufferAccess/GetAccessInfo APIs; refactors layout inference and fragment validation to use per-buffer indices plus read/write flags; adds reducer-aware layout logic in ReduceOp; and adds seven TileLang regression tests for issue 1719.

Changes

Cohort / File(s) Summary
ParallelOp header
src/op/parallel.h
Add BufferAccessInfo { Array<PrimExpr> indices; bool is_read; bool is_write }; replace prior maps/sets with BufferIndiceMap (std::unordered_map); add RecordBufferAccess() and GetAccessInfo(); change GetIndiceMap() signature.
ParallelOp implementation
src/op/parallel.cc
Replace per-buffer indice_map_/write-set usage with centralized BufferAccessInfo; iterate (buffer, access) pairs and use access.indices / access.is_read / access.is_write across IsCommonAccessIndice, InferLayout, ComputeLoopLayoutFromBuffer, CompleteBufferFragment, ValidateCandidateAgainstFragments, ComputePlanCandidate, and related flows; update error/reporting text.
ReduceOp implementation
src/op/reduce.cc
Add ComputeReducerLayout helper; compute red_layout/red_indices; route initialization, local/thread reductions, and BufferLoad/BufferStore to use reducer indices; add predicate-guarded writes when reducer layout exceeds dst layout; update InferLayout to consider reducer layout.
TileLang tests
testing/python/issue/test_tilelang_issue_1719.py
Add seven regression tests (test_issue_1719_layout_1 through test_issue_1719_layout_7) exercising fragment fill/reduce patterns, guarded writes, and kernel generation/indexing behaviors.
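The reducer-based flow described for src/op/reduce.cc can be pictured with a plain-Python sketch (illustrative only — reduce_sum_rows and the explicit phase comments are not the actual API; the real code operates on fragment layouts and TIR, not Python lists):

```python
# Hypothetical sketch of the three-phase reduction flow: reduce into a
# separately laid-out reducer buffer first, then copy the reducer to dst.

def reduce_sum_rows(src):
    n_rows = len(src)
    # Phase 1: allocate the reducer. In the real code its layout is
    # inferred first (ComputeReducerLayout), independently of dst.
    reducer = [0.0] * n_rows
    # Phase 2: accumulate src into the reducer using the reducer indices.
    for i in range(n_rows):
        for x in src[i]:
            reducer[i] += x
    # Phase 3: copy the reducer into the destination buffer; this is the
    # step that may need predicate-guarded writes when the reducer layout
    # exceeds the dst layout.
    dst = list(reducer)
    return dst
```

Splitting the reduction from the final store is what lets the pass guard the dst writes separately instead of widening dst_layout.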

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant ParallelOpNode
  participant Buffer
  participant LayoutInferencer
  participant FragmentValidator

  Caller->>ParallelOpNode: RecordBufferAccess(buffer, indices, is_write)
  Note over ParallelOpNode: Store BufferAccessInfo { indices, is_read, is_write }
  Caller->>ParallelOpNode: Trigger InferLayout / ComputePlan
  ParallelOpNode->>ParallelOpNode: For each (buffer, access) -> GetAccessInfo(buffer)
  ParallelOpNode->>LayoutInferencer: Provide buffer access infos (access.indices, access.is_read/is_write)
  LayoutInferencer->>FragmentValidator: Validate candidate fragments using access.indices & flags
  FragmentValidator-->>LayoutInferencer: validation result
  LayoutInferencer-->>ParallelOpNode: chosen layout/plan (may include red_layout)
  ParallelOpNode-->>Caller: return plan
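The RecordBufferAccess/GetAccessInfo flow in the diagram can be mimicked with a rough Python analogue (the class ParallelOpSketch and its method names are illustrative stand-ins, not the real C++ API in src/op/parallel.h):

```python
# Rough Python analogue of the centralized per-buffer access tracking:
# one record per buffer holding indices plus read/write flags, merged
# across repeated accesses instead of separate indice/write-set maps.
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class BufferAccessInfo:
    indices: List[Tuple[int, ...]]
    is_read: bool = False
    is_write: bool = False


class ParallelOpSketch:
    def __init__(self):
        # analogue of BufferIndiceMap: buffer name -> BufferAccessInfo
        self._access: Dict[str, BufferAccessInfo] = {}

    def record_buffer_access(self, buffer, indices, is_write):
        # Reuse the existing record for this buffer, OR-ing in the new
        # access kind, so a buffer both read and written keeps one entry.
        info = self._access.setdefault(buffer, BufferAccessInfo(list(indices)))
        if is_write:
            info.is_write = True
        else:
            info.is_read = True
        return info

    def get_access_info(self, buffer):
        return self._access[buffer]
```

Keeping one record per buffer is what lets downstream passes iterate `(buffer, access)` pairs and branch on `access.is_read` / `access.is_write` instead of consulting separate maps.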

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hop through indices, flags, and glue,
One map now tells each buffer what to do.
Reducer lanes and guarded writes take flight,
Tests run, kernels hum into the night.
A carrot-nibble of code — neat and bright!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.15% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[Layout] Fix Layout Bugs in Parallel and Reduce' clearly and directly summarizes the main changes across the pull request.



@kurisu6912 kurisu6912 linked an issue Jan 23, 2026 that may be closed by this pull request
2 tasks
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/op/reduce.cc`:
- Around line 475-481: Summary: Fix the minor typo in a comment in
src/op/reduce.cc. Locate the comment block inside reduce.cc that reads "We
shouldn't widen the layout here, because is may be written by other parallel for
op" and change "is" to "it" so it reads "because it may be written by other
parallel for op"; no logic change required—only update the comment text in that
block (the commented-out if (dst_rep > src_rep) { ... } section).
🧹 Nitpick comments (1)
testing/python/issue/test_tilelang_issue_1719.py (1)

12-19: Consider initializing tmp1 or adding a clarifying comment.

tmp1 is allocated but never initialized before being used as a divisor at line 19. While the test only inspects the generated source (never executes the kernel), using uninitialized memory can be confusing for future maintainers who might try to run the kernel.

Consider either initializing tmp1 or adding a comment clarifying that this test only validates code generation, not runtime behavior.

💡 Option 1: Initialize tmp1
                 tmp1 = T.alloc_fragment((N, M), T.float32)
                 tmp2 = T.alloc_fragment((N, M), T.float32)
                 tmp3 = T.alloc_fragment((N, M, M), T.float32)
+                for i, k in T.Parallel(N, M):
+                    tmp1[i, k] = 1
                 for i, j, k in T.Parallel(N, M, M):
                     tmp3[i, j, k] = 1
💡 Option 2: Add clarifying comment
 def test_tilelang_issue_1719():
+    # This test validates code generation only; the kernel is not executed.
+    # tmp1 is intentionally left uninitialized as we only check the source.
     @tilelang.jit()
     def _buggy_kernel(M: int, N: int) -> tilelang.JITKernel:

@kurisu6912

This comment was marked as resolved.

@github-actions

This comment was marked as resolved.

@kurisu6912 kurisu6912 changed the title [Layout] Add ReadWrite Constraint in Layout Inference [Layout] Fix Layout Bugs in Parallel and Reduce Jan 27, 2026
@kurisu6912
Collaborator Author

@regression-perf

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/op/reduce.cc`:
- Around line 412-437: The proven-path currently stores BufferLoad(clear_buffer,
red_indices) while the unproven path stores the computed update (which may
include accumulation), causing inconsistent semantics when need_update is true;
change the analyzer->CanProve(predicate) branch in the block handling
need_duplicate so that it pushes BufferStore(dst_buffer, update, dst_indices)
(i.e., reuse the previously computed update value) when need_update is true (and
still push the plain clear_buffer load only when need_update is false), ensuring
both the proven and unproven branches store the same value; update references:
analyzer->CanProve(predicate), need_update, update, BufferStore, clear_buffer,
dst_buffer, dst_indices, BufferLoad.
🧹 Nitpick comments (4)
src/op/reduce.cc (1)

244-245: Consider null check for as_const_int results.

The as_const_int function returns a pointer that could be null if the expression isn't a constant integer. While the context likely guarantees constant extents, a defensive check would prevent potential crashes on unexpected inputs.

💡 Suggested defensive pattern
-    auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
-    auto red_rep = *as_const_int(red_layout->ReplicateExtent());
+    auto dst_rep_ptr = as_const_int(dst_layout->ReplicateExtent());
+    auto red_rep_ptr = as_const_int(red_layout->ReplicateExtent());
+    ICHECK(dst_rep_ptr && red_rep_ptr)
+        << "ReplicateExtent must be constant for reduce layout inference";
+    auto dst_rep = *dst_rep_ptr;
+    auto red_rep = *red_rep_ptr;
testing/python/issue/test_tilelang_issue_1719.py (3)

7-26: Test only prints kernel source without validation.

test_issue_1719_layout_1 compiles a kernel but only prints the source without any assertions. Consider adding validation to ensure the kernel compiles correctly and the layout is as expected.

💡 Suggested improvement
     kernel = _buggy_kernel()
     print(kernel.get_kernel_source())
+    # Verify kernel compiles without error
+    assert kernel.get_kernel_source() is not None

122-124: Expression statement tmp2[0] has no effect.

Line 124 loads tmp2[0] but doesn't do anything with the result. This appears to be either incomplete or intended to verify the buffer is accessible. Consider making the intent clearer:

💡 Clarify intent
             T.reduce_sum(tmp1, tmp2)
-            tmp2[0]
+            # Verify buffer is accessible after reduction
+            _ = tmp2[0]

77-150: Consider adding more assertions to remaining tests.

Tests 4-7 are smoke tests that verify compilation succeeds, which is valuable. However, consider adding at least minimal assertions to catch regressions:

  • Verify get_kernel_source() returns non-empty output
  • For tests that instantiate without parameters (like 4, 6, 7), ensure the kernel object is valid

This would make the tests more robust without significantly increasing complexity.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/op/reduce.cc`:
- Around line 244-245: The code dereferences the result of as_const_int on
dst_layout->ReplicateExtent() and red_layout->ReplicateExtent() into dst_rep and
red_rep without checking for nullptr; update the code in the reduction routine
to call as_const_int for both ReplicateExtent() results, check each pointer for
nullptr before dereferencing, and handle the symbolic (nullptr) case
appropriately (e.g., return an error/status, throw an exception, or fall back to
a runtime-handled path) so you avoid null pointer dereference when
ReplicateExtent is not a compile-time constant.
🧹 Nitpick comments (1)
src/op/reduce.cc (1)

163-186: Consider adding defensive bounds check for dim parameter.

The function accesses src_layout->InputShape()[dim] without validating that dim is within bounds. While the caller (Lower and InferLayout) validates dimensions elsewhere, adding a defensive check here would make the function safer for potential future use.

♻️ Suggested defensive check
 static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) {
+  ICHECK(dim >= 0 && static_cast<size_t>(dim) < src_layout->InputDim())
+      << "Reduction dimension out of bounds: " << dim;
   PrimExpr src_rep_extent = src_layout->ReplicateExtent();
   PrimExpr indice_rep_extent = src_layout->InputShape()[dim];

@github-actions

Performance Regression Test Report

Triggered by: @kurisu6912
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21384235880

Results

File Original Latency Current Latency Speedup
example_warp_specialize_gemm_softpipe_stage2 0.038113 0.039488 0.965179
example_warp_specialize_gemm_copy_1_gemm_0 0.038338 0.039583 0.968547
example_warp_specialize_gemm_barrierpipe_stage2 0.038784 0.039456 0.982968
example_gqa_decode 0.048737 0.049153 0.991537
example_gemm_autotune 0.022176 0.022336 0.992837
example_warp_specialize_gemm_copy_0_gemm_1 0.038945 0.039137 0.995094
example_convolution_autotune 0.991599 0.995095 0.996487
example_tilelang_gemm_fp8_2xAcc 0.190498 0.191092 0.996894
example_gqa_bwd_wgmma_pipelined 0.0695986 0.0697331 0.998071
example_gemm_intrinsics 0.035073 0.035136 0.998207
tilelang_example_sparse_tensorcore 0.0150184 0.0150342 0.998949
example_dynamic 0.65841 0.658987 0.999124
example_mha_inference 0.081154 0.081219 0.9992
example_tilelang_gemm_fp8_intrinsic 0.933298 0.933871 0.999386
example_gemv 0.284715 0.284847 0.999535
example_gqa_bwd 0.0498069 0.0498218 0.999703
example_gqa_bwd_tma_reduce_varlen 0.0522468 0.0522604 0.99974
example_mha_bwd_bshd_wgmma_pipelined 0.026206 0.0262124 0.999756
example_elementwise_add 0.295773 0.295826 0.999824
example_mha_bwd_bshd 0.0412342 0.0412404 0.999849
example_mha_fwd_varlen 0.0452667 0.0452729 0.999863
example_tilelang_gemm_splitk 1.42259 1.42265 0.999959
example_gemm 0.022657 0.022656 1.00004
example_vertical_slash_sparse_attn 0.237518 0.237484 1.00015
example_linear_attn_fwd 0.0369005 0.0368944 1.00016
example_gemm_schedule 0.0325879 0.032582 1.00018
example_tilelang_gemm_splitk_vectorize_atomicadd 1.42301 1.42272 1.00021
block_sparse_attn_tilelang 0.0102488 0.0102465 1.00022
example_tilelang_gemm_fp8 0.322165 0.322019 1.00045
example_linear_attn_bwd 0.153163 0.153072 1.0006
example_mha_bwd_bhsd 0.0406606 0.0406345 1.00064
example_topk 0.01088 0.010848 1.00295
example_dequant_gemv_fp16xint4 0.0285227 0.0284234 1.00349
example_per_token_cast_to_fp8 0.00745031 0.00741336 1.00498
example_group_per_split_token_cast_to_fp8 0.0102588 0.0102028 1.00548
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0155482 0.0154496 1.00638
example_mha_sink_fwd_bhsd_sliding_window 0.0158008 0.0157001 1.00641
example_dequant_gemm_w4a8 5.43462 5.39825 1.00674
example_tilelang_nsa_fwd 0.0070482 0.00698836 1.00856
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0145623 0.0144372 1.00867
example_tilelang_nsa_decode 0.00738855 0.00732402 1.00881
example_mha_sink_bwd_bhsd 0.0629196 0.0623469 1.00919
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.014699 0.0145616 1.00944
example_blocksparse_gemm 0.0228748 0.0226602 1.00947
example_mha_sink_bwd_bhsd_sliding_window 0.0452572 0.0448298 1.00953
fp8_lighting_indexer 0.0377058 0.0373444 1.00968
example_tilelang_block_sparse_attn 0.0102545 0.0101517 1.01013
example_mha_sink_fwd_bhsd 0.0160021 0.0158416 1.01013
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0156188 0.0154576 1.01043
example_tilelang_sparse_gqa_decode_varlen_indice 0.0150063 0.0148441 1.01092
sparse_mla_bwd 0.383305 0.379096 1.0111
sparse_mla_fwd 0.132243 0.130781 1.01118
example_tilelang_sparse_gqa_decode_varlen_mask 0.0215384 0.021299 1.01124
example_convolution 1.34984 1.33327 1.01243
example_gqa_sink_bwd_bhsd 0.0421852 0.0416611 1.01258
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.46832 3.42449 1.0128
example_gqa_sink_bwd_bhsd_sliding_window 0.0259053 0.0255674 1.01322
example_mla_decode 0.467593 0.461097 1.01409
sparse_mla_fwd_pipelined 0.0969941 0.0956304 1.01426
example_dequant_gemm_fp4_hopper 1.08229 1.06514 1.0161
example_dequant_gemm_bf16_fp4_hopper 0.583564 0.573961 1.01673
topk_selector 0.0534994 0.0526175 1.01676
example_dequant_gemm_bf16_mxfp4_hopper 0.518468 0.50875 1.0191

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/issue/test_tilelang_issue_1719.py`:
- Around line 50-75: The test test_issue_1719_layout_3 creates CUDA tensors and
must be skipped on CPU-only runners; add the `@tilelang.testing.requires_cuda`
decorator immediately above the test_issue_1719_layout_3 definition to ensure it
only runs when CUDA is available, leaving the rest of the function
(buggy_kernel, kernel invocation, and assertions) unchanged.

LeiWang1999
LeiWang1999 previously approved these changes Jan 27, 2026
Member

@LeiWang1999 LeiWang1999 left a comment


LGTM

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/op/reduce.cc`:
- Around line 389-414: The code unconditionally calls inv.pop_back() after
dst_layout->Inverse()->Forward(dst_th_indices), which breaks when the replicate
component was condensed (i.e. Forward did not append an extra element); guard
the pop by checking the returned inv size against the expected size (e.g. only
call inv.pop_back() if inv.size() > dst_layout->InputDim() or inv.size() ==
dst_layout->InputDim() + 1), and handle the mismatch case appropriately (skip
the pop and/or set predicate to false or bail out) before using inv in the loop
that builds the predicate; update the logic around inv.pop_back(), the
subsequent for-loop comparing inv[i] to dst_vars[i]->var, and keep the
analyzer->Simplify(predicate) call.
- Around line 277-321: When a layout mismatch causes creation of clear_buffer
(need_duplicate=true) but require_init is false (e.g., max/min with
clear=false), the duplicate buffer remains uninitialized; seed it from the
original dst_buffer after decl_buffer. After creating clear_buffer (the
decl_buffer call), if need_duplicate && !require_init, emit elementwise copies
that read from dst_buffer (BufferLoad(dst_buffer, ...)) and write into
clear_buffer (BufferStore(clear_buffer, ...)) using the appropriate index
mapping between dst_layout and red_layout (use the same index expressions you
use elsewhere, e.g., dst/red indices), so the accumulator starts with the
original dst values; alternatively, set require_init=true and initialize from
MakeInitValue if you prefer forcing clear. Ensure you reference decl_buffer,
clear_buffer, dst_buffer, need_duplicate, require_init, BufferLoad and
BufferStore when implementing the fix.
🧹 Nitpick comments (1)
testing/python/issue/test_tilelang_issue_1719.py (1)

96-106: Unused expression on line 104.

tmp2[0] is a standalone expression that reads the value but discards it. If this is intentional to test that the read compiles correctly after reduce_sum, consider adding a brief comment explaining the intent. Otherwise, if this was meant to be an assignment or assertion, please fix accordingly.

💡 If intentional, add a clarifying comment
             T.reduce_sum(tmp1, tmp2)
-            tmp2[0]
+            # Access tmp2[0] to verify layout inference handles post-reduction reads
+            _ = tmp2[0]

@kurisu6912 kurisu6912 merged commit fa9660b into tile-ai:main Jan 28, 2026
6 checks passed


Development

Successfully merging this pull request may close these issues.

[BUG] Layout Inference on Reduce Ops shouldn't Widen dst_layout

2 participants