[Refactor] Phase out legacy `alloc_local` statements in examples and introduce processing for floating fragment buffers #1495
Conversation
…ious examples and operations

* Updated multiple files to replace local buffer allocations with variable allocations for improved performance and clarity.
* Changed `alloc_local` to `alloc_var` in examples related to attention mechanisms, deep learning models, and GEMM operations.
* Enhanced code readability and maintainability by streamlining buffer management across different components.
* Ensured consistent handling of buffer scopes and types throughout the codebase.
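For reviewers new to the pattern, here is a minimal before/after sketch of the refactor (a schematic excerpt in the style of the diffs reviewed below, not a complete kernel; buffer names are illustrative):

```python
# Before: one-element local buffer, every access goes through [0]
scale_local = T.alloc_local([1], accum_dtype)
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])

# After: scalar variable, assigned and read directly
scale_local = T.alloc_var(accum_dtype)
scale_local = T.exp2(lse_local_split - lse_logsum_local)
```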
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the repository's formatting checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Refactors many kernels to replace one-element local tensors with scalar variables or fragment buffers, centralizes buffer-scope checks into predicate helpers, and adds detection/early-layout assignment for fragment buffers accessed outside TileOps.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/deepseek_mla/example_mla_decode.py (1)
**170-192**: Fix inconsistent scalar variable indexing.

On lines 170-173, scalar variables are allocated using `alloc_var`, but on line 190, `scale_local` is accessed with `[0]` indexing: `o_accum_local[i] += po_local[i] * scale_local[0]`. Since `scale_local` is now a scalar variable (not an array), the `[0]` indexing should be removed.

🔎 Proposed fix

```diff
 for i in T.Parallel(dim):
-    o_accum_local[i] += po_local[i] * scale_local[0]
+    o_accum_local[i] += po_local[i] * scale_local
```

src/transform/common/loop_parallel_transform_utils.h (1)
**138-162**: Consolidate duplicate `IsFragmentBuffer` definition for safety.

The code correctly uses `IsFragmentBuffer(op->buffer)` from `src/op/utils.h`, which includes a `buffer.defined()` guard. However, `src/transform/common/loop_fusion_utils.h` (lines 70-74) defines a duplicate version without this guard. This inconsistency creates a potential null pointer dereference risk. Consolidate by removing the local definition in `loop_fusion_utils.h` and using the safe canonical version from `op/utils.h`.
🧹 Nitpick comments (3)
examples/gdn/example_chunk_o_bwd.py (2)
**278-280**: Consider renaming variable for clarity.

The variable `dg_last_local_0` is reused on lines 278-280 with a completely different semantic meaning:

- Lines 226, 253: Accumulates gradient contributions
- Line 278: Assigned to a gate value from the G tensor
- Line 280: Scaled by an exponential

While this reuse is functionally valid, it makes the code harder to understand. Consider using a separate variable name (e.g., `G_last_scaled`) for lines 278-280 to improve readability.
**291-295**: Complex conditional expression may benefit from clarification.

The conditional expression spans multiple lines and involves computing an exponential difference. Consider extracting the condition `G_last_local - G[bb, bs * block_S + i_s, bh] <= 0` into a named variable to improve readability:

```python
exp_diff = G_last_local - G[bb, bs * block_S + i_s, bh]
dk_fragment[i_s, i_k] = (
    dk_fragment[i_s, i_k] * T.exp(exp_diff) if exp_diff <= 0 else 0
)
```

src/op/parallel.cc (1)
**195-198**: Consider using `DLOG(INFO)` instead of `LOG(INFO)` for debug output.

This logging statement will execute in production builds and may produce excessive output. Based on the surrounding code patterns (e.g., lines 402-403 use `DLOG(INFO)`), debug-level logging with `DLOG` would be more appropriate here.

🔎 Suggested fix

```diff
 if (IsFragmentBuffer(bl->buffer) && !indice_map_.count(bl->buffer)) {
-  LOG(INFO) << "ExpandLetBindings: set buffer " << bl->buffer
-            << " with indices " << bl->indices;
+  DLOG(INFO) << "ExpandLetBindings: set buffer " << bl->buffer
+             << " with indices " << bl->indices;
   indice_map_.Set(bl->buffer, bl->indices);
 }
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (35)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- examples/attention_sink/example_mha_sink_bwd_bhsd.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
- examples/deepseek_mla/example_mla_decode.py
- examples/deepseek_mla/example_mla_decode_paged.py
- examples/deepseek_mla/example_mla_decode_persistent.py
- examples/deepseek_mla/example_mla_decode_ws.py
- examples/deepseek_v32/fp8_lighting_indexer.py
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py
- examples/fusedmoe/example_fusedmoe_tilelang.py
- examples/gdn/example_chunk_delta_bwd.py
- examples/gdn/example_chunk_delta_h.py
- examples/gdn/example_chunk_o_bwd.py
- examples/gemm_sp/example_custom_compress.py
- examples/grouped_gemm/example_grouped_gemm_bwd.py
- examples/grouped_gemm/example_grouped_gemm_fwd.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/seer_attention/block_sparse_attn_tilelang.py
- src/op/atomic_add.cc
- src/op/copy.cc
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm_sp.cc
- src/op/parallel.cc
- src/op/reduce.cc
- src/op/utils.h
- src/transform/common/loop_parallel_transform_utils.h
- src/transform/layout_inference.cc
- src/transform/legalize_safe_memory_access.cc
- src/transform/loop_partition.cc
💤 Files with no reviewable changes (1)
- src/transform/legalize_safe_memory_access.cc
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- examples/seer_attention/block_sparse_attn_tilelang.py
- examples/fusedmoe/example_fusedmoe_tilelang.py
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/op/parallel.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm.cc
🧬 Code graph analysis (20)
src/transform/common/loop_parallel_transform_utils.h (1)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)

src/transform/loop_partition.cc (2)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)
- src/transform/common/loop_fusion_utils.h (1): `IsFragmentBuffer` (70-75)

src/op/reduce.cc (2)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)
- src/transform/common/loop_fusion_utils.h (1): `IsFragmentBuffer` (70-75)

examples/seer_attention/block_sparse_attn_tilelang.py (4)
- src/tl_templates/cuda/reduce.h (1): `T` (178-250)
- tilelang/language/allocate.py (1): `alloc_fragment` (72-85)
- tilelang/language/copy_op.py (1): `copy` (14-95)
- tilelang/language/fill_op.py (1): `fill` (9-36)

src/transform/layout_inference.cc (3)
- src/op/utils.h (2): `IsFragmentBuffer` (33-35), `IsLocalBuffer` (50-53)
- src/transform/common/loop_fusion_utils.h (1): `IsFragmentBuffer` (70-75)
- src/layout/layout.cc (2): `FullyReplicated` (552-556), `FullyReplicated` (552-553)

src/op/parallel.cc (2)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)
- src/transform/common/loop_fusion_utils.h (1): `IsFragmentBuffer` (70-75)

src/op/gemm.cc (2)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)
- src/transform/common/loop_fusion_utils.h (1): `IsFragmentBuffer` (70-75)

src/op/atomic_add.cc (1)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)

src/op/utils.h (2)
- src/transform/layout_inference.cc (4): `buffer` (593-616), `buffer` (593-593), `buffer` (924-943), `buffer` (924-924)
- src/transform/legalize_safe_memory_access.cc (6): `buffer` (61-69), `buffer` (61-61), `buffer` (72-115), `buffer` (72-73), `buffer` (266-271), `buffer` (266-266)

examples/deepseek_mla/example_mla_decode_ws.py (2)
- tilelang/language/fill_op.py (1): `clear` (39-62)
- tilelang/language/loop.py (2): `Pipelined` (58-95), `Parallel` (13-33)

examples/deepseek_mla/example_mla_decode.py (1)
- tilelang/language/allocate.py (3): `alloc_var` (89-89), `alloc_var` (93-93), `alloc_var` (96-150)

examples/deepseek_mla/example_mla_decode_persistent.py (2)
- tilelang/language/loop.py (2): `Pipelined` (58-95), `Parallel` (13-33)
- src/transform/storage_rewrite.cc (1): `i` (720-860)

examples/gemm_sp/example_custom_compress.py (1)
- tilelang/language/allocate.py (1): `alloc_shared` (40-55)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (4)
- tilelang/language/v2/dtypes.py (1): `int32` (236-236)
- tilelang/language/fill_op.py (1): `clear` (39-62)
- tilelang/language/loop.py (1): `Parallel` (13-33)
- src/transform/storage_rewrite.cc (1): `i` (720-860)

examples/deepseek_v32/fp8_lighting_indexer.py (2)
- tilelang/language/loop.py (1): `Pipelined` (58-95)
- tilelang/language/kernel.py (1): `threads` (214-218)

examples/deepseek_mla/example_mla_decode_paged.py (2)
- tilelang/language/fill_op.py (1): `clear` (39-62)
- tilelang/language/loop.py (2): `Pipelined` (58-95), `Parallel` (13-33)

examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (3)
- tilelang/language/allocate.py (1): `alloc_fragment` (72-85)
- tilelang/language/copy_op.py (1): `copy` (14-95)
- tilelang/language/fill_op.py (1): `fill` (9-36)

src/op/gemm_sp.cc (1)
- src/op/utils.h (1): `IsFragmentBuffer` (33-35)

examples/gdn/example_chunk_o_bwd.py (2)
- tilelang/language/allocate.py (2): `alloc_fragment` (72-85), `alloc_shared` (40-55)
- tilelang/language/fill_op.py (1): `clear` (39-62)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
- tilelang/language/fill_op.py (1): `clear` (39-62)
- tilelang/language/loop.py (2): `Pipelined` (58-95), `Parallel` (13-33)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (72)
examples/gdn/example_chunk_delta_h.py (2)
**159-159**: LGTM! Clean refactoring from single-element array to scalar allocation. Using `T.alloc_var(T.float32)` is the appropriate pattern for scalar values, eliminating unnecessary array indexing overhead.
**204-216**: Gating logic refactored correctly.

The scalar-based gating logic is cleaner and functionally equivalent:

- Line 204: Direct scalar assignment from the last element of the chunk
- Lines 209-213: Vectorized conditional correctly zeros out contributions where `G_last_local > G_fragment`
- Lines 214-216: Proper exponential scaling of the hidden state

The constant `1.442695` (≈ 1/ln(2)) correctly implements `exp(x)` via `exp2(x * 1.442695)`.
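As a quick sanity check of that identity, plain Python confirms that scaling by log2(e) turns `exp2` into `exp` (self-contained, independent of TileLang):

```python
import math

LOG2_E = 1.0 / math.log(2.0)  # log2(e) = 1/ln(2) ≈ 1.4426950408889634
for x in (-3.0, 0.0, 0.5, 10.0):
    # exp(x) == 2 ** (x * log2(e)) up to floating-point rounding
    assert math.isclose(math.exp(x), 2.0 ** (x * LOG2_E), rel_tol=1e-12)
```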
examples/fusedmoe/example_fusedmoe_tilelang.py (2)

**162-166**: LGTM! Scalar allocation refactor improves clarity.

The refactor correctly replaces single-element local arrays with direct scalar indexing from tensors. Using `cur_group_idx = group_idx_for_bx[bx]` and `cur_group_size = group_sizes[cur_group_idx]` eliminates unnecessary array indirection and makes the code more straightforward. All subsequent uses of these scalars (in `m_start`, `actual_rows`, and tensor indexing) are consistent and correct.

Also applies to: 179-179, 187-187
**212-216**: LGTM! Consistent scalar allocation pattern in Step 2 kernel.

The Step 2 kernel applies the same scalar allocation refactor as Step 1, maintaining consistency across the codebase. The pattern is correctly implemented and improves code readability by removing unnecessary array indirection.
Also applies to: 228-228
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (2)
**140-140**: LGTM! Appropriate refactoring to fragment allocation.

The change from `alloc_local` to `alloc_fragment` correctly standardizes the block mask allocation to use fragment buffers, which aligns with the PR's refactoring objectives and is appropriate for per-thread mask storage used in downstream fragment operations.
**147-147**: LGTM! Efficient bulk copy replaces element-by-element loop.

The refactoring to use `T.copy` for bulk population of the block mask is more efficient and cleaner than the previous per-element loop. The slice notation correctly extracts the relevant dimension from `BlockSparseMask`.

benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (3)
**141-141**: LGTM! Consistent fragment allocation refactoring.

The change to `alloc_fragment` maintains consistency with the corresponding example file and correctly standardizes block mask allocation.
**148-148**: LGTM! Efficient bulk copy implementation.

The bulk copy using `T.copy` is consistent with the example file and provides better performance than element-by-element population.
**155-155**: LGTM! More explicit boolean comparison.

The explicit `!= 0` comparison improves code clarity and makes the intent more obvious, especially when working with fragment buffers. This is consistent with the pattern already used in the example file.
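Taken together, the reviewed pattern looks roughly like this (a schematic sketch; the mask length and slice indices are assumptions, not the file's exact code):

```python
block_mask = T.alloc_fragment([num_mask_blocks], "bool")   # was alloc_local
T.copy(BlockSparseMask[bx, by, :], block_mask)             # bulk population
for k in T.serial(num_mask_blocks):
    if block_mask[k] != 0:                                 # explicit comparison
        ...  # process only unmasked blocks
```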
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)

**127-131**: LGTM! Scalar allocations improve clarity.

The refactoring from single-element local arrays to scalar variables using `T.alloc_var()` is correct and improves code readability. All subsequent operations have been properly updated to work with scalars instead of array indexing.
**133-135**: Verify initialization pattern consistency for scalar variables.

Two different initialization patterns are used for scalar variables of the same type:

- Line 133: `T.clear(lse_logsum_local)`
- Line 135: `lse_max_local = -T.infinity(accum_dtype)` (direct assignment)

Please confirm whether `T.clear()` is the intended way to zero-initialize scalars allocated with `T.alloc_var()`, or if direct assignment is preferred for consistency.
**136-156**: LGTM! All scalar operations are correctly updated.

The refactoring properly removes array indexing throughout the combine function:
- Direct scalar comparisons and assignments (lines 137-140, 143-145, 148-152)
- Scalar arithmetic operations maintain correct semantics
- Fragment operations correctly use scalars in computations (line 154)
The logic flow is preserved while achieving cleaner, more intuitive code.
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
**328-332**: LGTM! Cleaner scalar access pattern.

The refactoring from a single-element local array to direct scalar assignment (`sink = Sinks[bx]`) eliminates unnecessary indirection and improves code readability. The scalar `sink` is appropriately used in the computation without array indexing.
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

**334-338**: LGTM! Consistent scalar refactoring.

The change mirrors the GQA variant and follows the same pattern: replacing single-element local array allocation with direct scalar assignment. This improves code clarity and eliminates unnecessary indexing operations.
examples/gdn/example_chunk_delta_bwd.py (3)
**306-307**: LGTM! Scalar allocation simplifies the code.

The refactor from single-element local array to scalar variables (`G_last_local` and `G_last_local_exp`) is cleaner and removes unnecessary array indexing. The values are correctly derived from `G_shared[block_S - 1]` and used in subsequent computations.
**308-313**: LGTM! Conditional logic correctly implements numerical stability.

The refactor simplifies the conditional update by using a ternary expression. The pattern computes all exponentials first (line 309), then conditionally applies them (lines 311-313). While this computes `exp(G_last_local - G_fragment[i_s2])` even when the difference is positive (and the result won't be used), IEEE floating point handles potential overflow gracefully by returning infinity, which is safely discarded when the condition sets the result to 0. This separation also enables better parallelization.
**331-331**: LGTM! Scalar usage is correct.

The direct multiplication by `G_last_local_exp` (without array indexing) is correct and consistent with the scalar refactor. This matches the torch reference logic where `dh_tmp` is scaled by the exponential of the last gate value.

examples/grouped_gemm/example_grouped_gemm_bwd.py (3)
**30-31**: LGTM: Scalar allocation simplifies the code.

The refactor from single-element arrays to scalar variables (`T.alloc_var`) is cleaner and aligns with the PR objectives. Type annotations are preserved correctly.
**50-51**: Native `if` statements within `T.Parallel` are standard across the codebase and represent the intentional direction of TileLang. This pattern appears consistently in multiple examples (grouped_gemm_fwd, grouped_gemm_bwd, fusedmoe) and was explicitly adopted by the TileLang maintainers in PR #162. No code generation concerns apply.
**37-41**: No action needed: `cur_batch_idx` is safely initialized.

`T.alloc_var` in TileLang defaults to zero initialization when no explicit init parameter is provided. Line 30 allocates `cur_batch_idx` without an explicit initializer, so it defaults to zero. The subsequent loop logic on line 37 safely reads this initialized value. The concern about uninitialized variable usage is not valid; the code is correct.

examples/grouped_gemm/example_grouped_gemm_fwd.py (4)
**64-65**: LGTM! Clean refactor to scalar variables.

The migration from single-element local arrays to `T.alloc_var` scalar allocations simplifies the code by eliminating unnecessary array indexing throughout the batch index computation logic.
**73-75**: LGTM! Consistent scalar usage.

The scalar variables are used correctly throughout the batch index computations and tensor indexing. The refactor maintains logical equivalence while improving readability by removing array indirection.
Also applies to: 80-80
**71-71**: Verify initialization of `cur_batch_idx` before use.

Line 71 references `cur_batch_idx` on the right-hand side of the assignment within the conditional expression. If the first iteration of the loop has `in_cur_batch_idx = false`, this would read from an uninitialized variable. While the old array-based allocation would zero-initialize, it's unclear whether `T.alloc_var` provides default initialization.

The logic likely prevents this in practice (since `batch_padded_offsets[0]` should be 0 and `m_start_padded >= 0`), but the code would be more robust with explicit initialization.

Run the following script to verify whether `T.alloc_var` guarantees initialization and check for any related documentation:

```bash
#!/bin/bash
# Description: Search for alloc_var initialization behavior and usage patterns

# Check for alloc_var initialization semantics
rg -n -C5 'def alloc_var'

# Look for other examples where alloc_var is explicitly initialized
rg -n -A3 'alloc_var.*=.*T\.alloc_var' --type=py

# Check for patterns where alloc_var is used before assignment
ast-grep --pattern $'$VAR = T.alloc_var($$$)
$$$
$_ = T.if_then_else($_, $_, $VAR)'
```
examples/deepseek_mla/example_mla_decode_persistent.py (1)
**53-56**: LGTM! Scalar allocation refactoring is correct.

The refactoring from single-element local arrays to scalar variables using `alloc_var` is properly implemented. All usages throughout the combine function correctly access these scalars directly without array indexing.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
**124-127**: LGTM! Consistent scalar variable refactoring.

The conversion from fragment-based to scalar-based storage is correctly implemented. All operations (lines 131-144) properly use direct scalar access without indexing.
examples/deepseek_v32/fp8_lighting_indexer.py (1)
**133-142**: LGTM! Scalar variable refactoring is correct.

The refactoring from array-based to scalar-based min/max tracking is properly implemented. All subsequent usages (lines 147-168) correctly access these scalars directly.
examples/deepseek_mla/example_mla_decode_paged.py (1)
**191-211**: LGTM! Scalar refactoring is correctly implemented.

The scalar variable refactoring is properly completed with consistent direct access throughout the combine function. Line 211 correctly uses `scale_local` without indexing.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
**115-115**: No changes needed. The code is correct as written.

`T.alloc_var()` creates a 1-element buffer (via `T.alloc_buffer([1], dtype, ...)`), not a scalar variable. The `[0]` indexing throughout the code (lines 266, 273, 276, 281, 288, 289, 295, 298, 303) is required and correct: it accesses the single element of the allocated buffer, which is the intended usage pattern.

Likely an incorrect or invalid review comment.
src/transform/loop_partition.cc (2)
**31-31**: LGTM!

Adding the include for the centralized `IsFragmentBuffer` predicate improves consistency with other parts of the codebase.
**222-234**: LGTM!

The refactoring from direct string comparison (`op->buffer.scope() == "local.fragment"`) to the `IsFragmentBuffer(op->buffer)` predicate is a good improvement. This centralizes the fragment buffer detection logic, making it easier to maintain and ensuring consistency across the codebase.

src/op/gemm_sp.cc (3)
**224-224**: LGTM!

Using `IsFragmentBuffer(c_)` for the precondition check is consistent with the PR's goal of centralizing fragment buffer detection.
**276-276**: LGTM!

The predicate-based check for the `a_` fragment buffer is consistent with the refactoring pattern.
**290-290**: LGTM!

The predicate-based check for the `b_` fragment buffer completes the consistent use of `IsFragmentBuffer` throughout the `InferLayout` method.

src/transform/common/loop_parallel_transform_utils.h (1)
**20-21**: LGTM!

The include is correctly added to use the centralized `IsFragmentBuffer` predicate.

src/op/reduce.cc (2)
**383-384**: LGTM!

The refactoring to use `IsFragmentBuffer(src) && IsFragmentBuffer(dst)` for the layout inference condition is consistent with the PR's goal of centralizing fragment buffer detection.
**521-521**: LGTM!

The change to use `IsFragmentBuffer(this->src) && IsFragmentBuffer(this->dst)` aligns with the predicate-based approach being adopted across the codebase.

src/op/fill.cc (1)
**159-181**: LGTM! Clean refactor to predicate-based buffer scope checks.

The transition from string-based scope comparisons to centralized predicates (`IsFragmentBuffer`, `IsLocalBuffer`, `IsSharedBuffer`, `IsGlobalBuffer`) improves maintainability and consistency across the codebase.

src/op/parallel.cc (4)
**209-223**: Good optimization: filtering let-binding expansion to only used variables.

This change correctly limits the expansion to variables that are actually referenced in `root_`, avoiding unnecessary work. The implementation properly collects used vars via `PostOrderVisit` and gates expansion accordingly.
**151-165**: LGTM! Consistent use of `IsFragmentBuffer` predicate in visitor methods.

The refactoring from `op->buffer.scope() == "local.fragment"` to `IsFragmentBuffer(op->buffer)` is applied consistently in both `BufferStoreNode` and `BufferLoadNode` visitors.
**275-336**: LGTM! Consistent predicate usage across layout inference logic.

The `IsFragmentBuffer` predicate is now used consistently throughout the layout inference code paths, replacing all direct string comparisons.
**15-15**: Duplicate definition claim is inaccurate.

The `IsFragmentBuffer` function exists in both headers, but the implementations are NOT identical:

- `src/op/utils.h` (inline): `return buffer.defined() && buffer.scope() == "local.fragment";`
- `src/transform/common/loop_fusion_utils.h` (non-inline): `return scope == "local.fragment";`

The non-inline version in `loop_fusion_utils.h` omits the `buffer.defined()` check. Additionally, one is marked inline while the other is not, creating a more serious concern: the inline specifier cannot re-declare a function that was already defined in the translation unit as non-inline.

This is a pre-existing issue in files that include both headers (e.g., `layout_inference.cc`, `fill.cc`, `atomic_add.cc`, `copy.cc`), but the current change to `parallel.cc` only includes the inline version from `utils.h` and does not introduce new problems.

Likely an incorrect or invalid review comment.
src/op/gemm.cc (5)
**523-533**: LGTM! Correct use of `IsFragmentBuffer` for A, B, C buffer classification in `Lower()`.

The refactoring correctly identifies fragment buffers for determining the GEMM operation variant (`gemm_rs`, `gemm_sr`, `gemm_ss`) and validates that C is a fragment buffer.
**605-617**: LGTM! Consistent predicate usage in Volta target path.

The checks for C and A buffer scopes are correctly refactored to use `IsFragmentBuffer`.
**633-667**: LGTM! Ampere/Turing/SM120/SM100 path correctly updated.

Fragment buffer checks for A and B in the MMA code path use the centralized predicate consistently.
**669-714**: LGTM! Hopper target path updated consistently.

The `IsFragmentBuffer(c_)` check and the implicit else branches for fragment A/B are handled correctly.
**775-809**: LGTM! CDNA target path uses predicates correctly.

The fragment buffer checks for C, A, and B in the MFMA code path are consistently updated.

src/op/atomic_add.cc (3)
src/op/atomic_add.cc (3)
328-339: LGTM! Layout comparison guard updated to use predicates.The condition correctly uses
IsFragmentBufferfor both source and destination buffers when comparing fragment layouts.
434-445: LGTM! Visitor methods inAtomicLoopNestCollectoruse predicates consistently.Both
BufferStoreNodeandBufferLoadNodevisitors correctly useIsFragmentBufferfor buffer classification.
476-477: LGTM! Fragment buffer check in layout inference loop.The negated check
!IsFragmentBuffer(buf)correctly replaces the previous string comparison.examples/minference/example_vertical_slash_sparse_attn.py (4)
133-136: LGTM! Correct migration from single-element arrays to scalar variables.Replacing
T.alloc_local([1], dtype)withT.alloc_var(dtype=int_dtype)forblock_countandcolumn_counteliminates unnecessary array overhead for scalar values. This aligns with the PR objective of phasing out legacyalloc_localfor single-element allocations.
**163-177**: LGTM! Loop bounds and control flow correctly use scalar values.

The `T.serial(block_count)` loop bound and other usages correctly operate on the scalar variable directly.
T.serial(block_count)loop bound and other usages correctly operate on the scalar variable directly.
210-285: LGTM! Conditional logic and macro calls updated for scalar semantics.All conditional checks (
column_count != 0), loop bounds (T.ceildiv(column_count, block_N)), and macro parameter passing correctly use the scalarcolumn_countvalue throughout the prefetch and compute sections.src/op/copy.cc (6)
776-780: LGTM! Clean refactoring to use predicate.The replacement of direct scope string comparison with
IsFragmentBuffer(dst)improves code maintainability and consistency across the codebase.
**808-826**: LGTM! Tensor memory checks properly refactored.

Both `CheckTMemLoad` and `CheckTMemStore` now use the centralized `IsFragmentBuffer` predicate, maintaining consistency with other memory operation checks in the file.
CheckTMemLoadandCheckTMemStorenow use the centralizedIsFragmentBufferpredicate, maintaining consistency with other memory operation checks in the file.
953-958: Refined warning logic for local buffer copies.The changes improve precision by:
- Using the
IsLocalBufferpredicate for cleaner buffer type checking- Narrowing the warning to only fire for local→non-local copies (which can cause conflicted writes)
- Eliminating false-positive warnings for non-local→local copies
This is an intentional behavior refinement, not just a refactoring.
**1234-1243**: LGTM! Tensor memory copy type detection refactored.

The conditions for detecting `tcgen05.ld` and `tcgen05.st` operations now use the centralized predicate, maintaining consistency with the rest of the codebase.
**2063-2086**: LGTM! Fragment layout collection uses predicate.

The change in `CollectFragmentLayouts` aligns with the overall refactoring to use centralized buffer-type predicates.

src/op/utils.h (1)
**32-53**: Well-designed buffer-type predicates.

The predicates provide a clean abstraction for buffer classification:

- All predicates defensively check `buffer.defined()` before accessing scope
- `IsLocalBuffer` correctly handles both `"local"` and `"local.var"` scopes
- `IsSharedBuffer` provides flexibility with the `allow_dynamic` parameter (defaulting to `true`)
- Inline implementation ensures zero overhead

These centralized predicates improve maintainability and reduce the risk of inconsistent scope checks across the codebase.
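Paraphrasing those scope rules as plain Python (a sketch of the semantics only; the actual helpers are C++ inline functions in src/op/utils.h):

```python
def is_fragment_buffer(scope: str) -> bool:
    return scope == "local.fragment"

def is_local_buffer(scope: str) -> bool:
    return scope in ("local", "local.var")

def is_shared_buffer(scope: str, allow_dynamic: bool = True) -> bool:
    # "shared.dyn" is the dynamically sized shared-memory scope
    return scope == "shared" or (allow_dynamic and scope == "shared.dyn")
```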
src/transform/layout_inference.cc (7)
**23-23**: LGTM! Required include for buffer-type predicates.

The include of `utils.h` provides access to the centralized `IsFragmentBuffer` and `IsLocalBuffer` predicates used throughout this file.
**174-175**: LGTM! Consistent predicate usage for fragment buffer detection.

All replacements of direct scope string comparisons with `IsFragmentBuffer(buffer)` are functionally equivalent and improve code maintainability.

Also applies to: 383-383, 595-595, 790-790
**312-321**: Well-designed floating fragment buffer initialization.

Step 0 correctly handles fragment buffers accessed outside TileOps by:

- Assigning fully replicated layouts (since access patterns cannot be inferred from TileOp semantics)
- Using the captured `thread_bounds` from the access point
- Respecting pre-existing layouts in `layout_map`

This preprocessing step runs before strict layout inference (step 1), ensuring these buffers have valid layouts throughout the inference pipeline.
**436-438**: Correct placement of floating buffer detection.

Computing floating fragment buffers after the main collection traversal ensures that the set of TileOp nodes is complete before scanning for accesses outside those operations.
**862-964**: Robust floating fragment buffer detection algorithm.

The implementation correctly identifies fragment buffers accessed outside TileOps (see the sketch after this list):

- Node collection (lines 884-890): Uses `PostOrderVisit` to build a complete set of nodes inside TileOps
- Visitor pattern (lines 894-949): Tracks thread context while scanning for fragment buffer accesses
- Access classification (lines 924-943): Efficiently checks if accesses are floating, with early returns for non-fragment buffers, TileOp-internal accesses, and duplicates
- Thread bounds capture: Extracts bounds from the analyzer with appropriate fallback to `Range(0, 1)`

The comprehensive documentation (lines 862-877) with a concrete example makes the logic easy to understand and maintain.
**971-977**: Well-documented data member for floating buffers.

The data structure appropriately:

- Uses `std::unordered_map` for efficient lookup
- Stores `Buffer → Range` mappings to track thread bounds per buffer
- Employs correct hash/equality predicates (`ObjectPtrHash`, `ObjectPtrEqual`) for Buffer keys
- Includes clear documentation linking to `ComputeFloatingFragmentBuffers()`
**1245-1359**: Correct buffer classification in loop transformation logic.

The `VisitStmt_(ForNode)` method properly uses `IsLocalBuffer` and `IsFragmentBuffer` predicates to:
- Lines 1268, 1288, 1292: Detect register-local-only loops that don't require thread partitioning
- Lines 1313, 1318: Identify non-local memory accesses requiring vectorization
The refactored logic maintains the correct semantics for loop partitioning and vectorization decisions.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (3)
**138-142**: Good refactor from fragment to scalar allocations.

The change from single-element fragment allocations to scalar variables using `T.alloc_var()` removes unnecessary indirection and improves code clarity. This aligns well with the PR objectives.

However, see the critical initialization issue flagged below for `max_split`.
**153-167**: Scalar refactor correctly applied in reduction and accumulation loops.

The scalar operations in the log-sum-exp reduction (lines 153-157) and output accumulation (lines 158-165) are correctly refactored from array-indexed operations to direct scalar assignments and arithmetic. The logic is preserved and the code is more readable without the unnecessary `[0]` indexing.

The only issue is the reliance on `max_split`, which may be uninitialized (flagged above).
**144-146**: No changes needed. The code correctly uses `T.clear()` on variables created by `T.alloc_var()`, which returns a 1-element buffer, not a plain scalar. The `T.clear()` API is designed for buffers and properly initializes them to zero. This usage pattern is consistent throughout the codebase and is correctly documented in the type signatures.

Likely an incorrect or invalid review comment.
examples/seer_attention/block_sparse_attn_tilelang.py (1)
**111-111**: Verify fragment buffer performance characteristics with conditional mask access.

The change from `T.alloc_local` to `T.alloc_fragment` for `block_mask` (line 111) is syntactically correct and the bulk `T.copy` initialization (line 118) is well-formed. However, fragment memory is designed for accumulators in register files to minimize latency in tensor core operations, while this mask is only read in a conditional check at line 123.

The refactoring aligns with the PR's systematic standardization of fragment buffer usage, but this unconventional use case (scalar conditional reads from a fragment buffer) should be validated to ensure the TileLang compiler generates efficient code and doesn't introduce unexpected performance overhead or compiler issues.
```diff
 for k in T.serial(num_split):
-    lse_local_split[0] = glse[bz, by, k]
-    if lse_local_split[0] != 0:
-        max_split[0] = k
-        lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
+    lse_local_split = glse[bz, by, k]
+    if lse_local_split != 0:
+        max_split = k
+        lse_max_local = T.max(lse_max_local, glse[bz, by, k])
```
Critical: `max_split` used uninitialized when all lse values are zero.

At line 150, `max_split` is only assigned inside the conditional `if lse_local_split != 0`. If all log-sum-exp values are zero (which can occur when no valid blocks are processed in any split), `max_split` is never initialized but is subsequently used in conditionals at lines 154 and 159.

In the previous array-based implementation, `max_split[0]` would have been zero-initialized, providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.
🔎 Proposed fix: Initialize max_split before the loop

```diff
 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
+max_split = -1
 for k in T.serial(num_split):
     lse_local_split = glse[bz, by, k]
     if lse_local_split != 0:
         max_split = k
         lse_max_local = T.max(lse_max_local, glse[bz, by, k])
```

Initializing to `-1` ensures that if no valid splits are found, the conditions `k <= max_split` at lines 154 and 159 will be false for all k ≥ 0, correctly skipping the accumulation loops.
🤖 Prompt for AI Agents
In examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
around lines 147 to 151, max_split is only assigned inside the if and can remain
uninitialized when all lse values are zero; initialize max_split to -1
immediately before the for k in T.serial(num_split) loop so that when no valid
splits are found the later checks (k <= max_split) are false and accumulation
loops are skipped, ensuring defined behavior in the all-zero case.
```diff
 lse_local_split = T.alloc_var(accum_dtype)
 lse_logsum_local = T.alloc_var(accum_dtype)
 lse_max_local = T.alloc_var(accum_dtype)
 scale_local = T.alloc_var(accum_dtype)

 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
-lse_max_local[0] = -T.infinity(accum_dtype)
+lse_max_local = -T.infinity(accum_dtype)
 for k in T.serial(num_split):
-    lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
+    lse_max_local = T.max(lse_max_local, glse[bz, by, k])
 for k in T.Pipelined(num_split, num_stages=1):
-    lse_local_split[0] = glse[bz, by, k]
-    lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
-lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
+    lse_local_split = glse[bz, by, k]
+    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
 for k in T.serial(num_split):
     for i in T.Parallel(dim):
         po_local[i] = Output_partial[bz, by, k, i]
-    lse_local_split[0] = glse[bz, by, k]
-    scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
+    lse_local_split = glse[bz, by, k]
+    scale_local = T.exp2(lse_local_split - lse_logsum_local)
     for i in T.Parallel(dim):
         o_accum_local[i] += po_local[i] * scale_local[0]
```
Bug: `scale_local[0]` indexing on a scalar variable.

`scale_local` is allocated as a scalar via `T.alloc_var(accum_dtype)` at line 179, but is incorrectly indexed as `scale_local[0]` at line 196. This will cause an error since scalars cannot be indexed.
🔎 Proposed fix

```diff
 for i in T.Parallel(dim):
-    o_accum_local[i] += po_local[i] * scale_local[0]
+    o_accum_local[i] += po_local[i] * scale_local
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
    lse_max_local = T.max(lse_max_local, glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
    lse_local_split = glse[bz, by, k]
    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
    for i in T.Parallel(dim):
        po_local[i] = Output_partial[bz, by, k, i]
    lse_local_split = glse[bz, by, k]
    scale_local = T.exp2(lse_local_split - lse_logsum_local)
    for i in T.Parallel(dim):
        o_accum_local[i] += po_local[i] * scale_local
```
🤖 Prompt for AI Agents
In examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py around lines
176 to 196, scale_local is allocated as a scalar (T.alloc_var(accum_dtype)) but
later used with an index (scale_local[0]) which is invalid; change the usage to
treat scale_local as a scalar (remove the [0] indexing) wherever it is used
(e.g., replace scale_local[0] with scale_local) so the multiplication into
o_accum_local uses the scalar value, or alternatively allocate scale_local as a
1-element buffer if indexing is required—prefer the scalar fix for minimal
change.
```diff
 lse_local_split = T.alloc_var(accum_dtype)
 lse_logsum_local = T.alloc_var(accum_dtype)
 lse_max_local = T.alloc_var(accum_dtype)
 scale_local = T.alloc_var(accum_dtype)

 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
-lse_max_local[0] = -T.infinity(accum_dtype)
+lse_max_local = -T.infinity(accum_dtype)
 for k in T.serial(num_split):
-    lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
+    lse_max_local = T.max(lse_max_local, glse[bz, hid, k])
 for k in T.Pipelined(num_split, num_stages=1):
-    lse_local_split[0] = glse[bz, hid, k]
-    lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
-lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
+    lse_local_split = glse[bz, hid, k]
+    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
 for k in T.serial(num_split):
     for i in T.Parallel(dim):
         po_local[i] = Output_partial[bz, hid, k, i]
-    lse_local_split[0] = glse[bz, hid, k]
-    scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
+    lse_local_split = glse[bz, hid, k]
+    scale_local = T.exp2(lse_local_split - lse_logsum_local)
     for i in T.Parallel(dim):
         o_accum_local[i] += po_local[i] * scale_local[0]
```
Fix inconsistent scalar variable indexing.

Similar to example_mla_decode.py, scalar variables are allocated with `alloc_var` on lines 475-478, but line 495 still uses array indexing: `o_accum_local[i] += po_local[i] * scale_local[0]`.

Since `scale_local` is now a scalar variable, remove the `[0]` indexing.
🔎 Proposed fix

```diff
 for i in T.Parallel(dim):
-    o_accum_local[i] += po_local[i] * scale_local[0]
+    o_accum_local[i] += po_local[i] * scale_local
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
    lse_max_local = T.max(lse_max_local, glse[bz, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
    lse_local_split = glse[bz, hid, k]
    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
    for i in T.Parallel(dim):
        po_local[i] = Output_partial[bz, hid, k, i]
    lse_local_split = glse[bz, hid, k]
    scale_local = T.exp2(lse_local_split - lse_logsum_local)
    for i in T.Parallel(dim):
        o_accum_local[i] += po_local[i] * scale_local
```
🤖 Prompt for AI Agents
In examples/deepseek_mla/example_mla_decode_ws.py around lines 475 to 495,
scalar locals lse_local_split, lse_logsum_local, lse_max_local, and scale_local
are allocated with alloc_var but the code treats scale_local as an array at line
495 using scale_local[0]; update the usage to treat scale_local as a scalar by
removing the [0] index (i.e., use scale_local directly in the multiplication)
and verify no other array-style indexing remains for these alloc_var scalars.
```python
dg_last_local_0 = T.alloc_var(dtype=gate_dtype)
dg_last_local_1 = T.alloc_var(dtype=gate_dtype)
G_last_local = T.alloc_var(dtype=gate_dtype)
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Find and examine the specific file
fd -name "example_chunk_o_bwd.py" -type f
```

Repository: tile-ai/tilelang
Length of output: 290
🏁 Script executed:
```bash
#!/bin/bash
# Search for the file and examine relevant lines
find . -path "*examples/gdn*" -name "*bwd*" -type f 2>/dev/null | head -5
```

Repository: tile-ai/tilelang
Length of output: 183
🏁 Script executed:
```bash
#!/bin/bash
# Examine the specific file, focusing on lines around 202-204, 228, and 293
wc -l ./examples/gdn/example_chunk_o_bwd.py
```

Repository: tile-ai/tilelang
Length of output: 101
🏁 Script executed:
```bash
#!/bin/bash
# Read the relevant sections of the file
echo "=== Lines 200-210 (around allocation at 202-204) ==="
sed -n '200,210p' ./examples/gdn/example_chunk_o_bwd.py
echo ""
echo "=== Lines 225-235 (around line 228) ==="
sed -n '225,235p' ./examples/gdn/example_chunk_o_bwd.py
echo ""
echo "=== Lines 290-300 (around line 293) ==="
sed -n '290,300p' ./examples/gdn/example_chunk_o_bwd.py
echo ""
echo "=== Full context: lines 200-320 to understand the flow ==="
sed -n '200,320p' ./examples/gdn/example_chunk_o_bwd.py | cat -n
```

Repository: tile-ai/tilelang
Length of output: 10081
🏁 Script executed:
```bash
#!/bin/bash
# Let me see more context around where G_last_local should be assigned
echo "=== Checking lines 74-82 for G_last_local assignment ==="
sed -n '74,82p' ./examples/gdn/example_chunk_o_bwd.py | cat -n
```

Repository: tile-ai/tilelang
Length of output: 540
🏁 Script executed:
```bash
#!/bin/bash
# Check the full context where G_last_local is being used
echo "=== Lines 74-105 context ==="
sed -n '274,310p' ./examples/gdn/example_chunk_o_bwd.py | cat -n
```

Repository: tile-ai/tilelang
Length of output: 2720
🏁 Script executed:
```bash
#!/bin/bash
# Search for all assignments to G_last_local in the file
echo "=== All lines where G_last_local is used/assigned ==="
grep -n "G_last_local" ./examples/gdn/example_chunk_o_bwd.py
```

Repository: tile-ai/tilelang
Length of output: 397
Initialize `G_last_local` before usage on lines 292-293.

The scalar variable `G_last_local` is allocated on line 204 and cleared on line 228, but it is never assigned a meaningful value before being used on lines 292-293 in the expression `G_last_local - G[bb, bs * block_S + i_s, bh]`. Since it remains at the cleared (zero) value, this produces incorrect computation. Assign `G_last_local` from the G tensor before its usage, similar to how `dg_last_local_0` is initialized from G on line 279 (e.g., `G_last_local = G[bb, bs * block_S + block_S - 1, bh]`).
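A one-line sketch of that suggestion (indices mirror the prose above; exact placement depends on the surrounding block loop):

```python
# Initialize from the last gate value of the current block before use
G_last_local = G[bb, bs * block_S + block_S - 1, bh]
```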
🤖 Prompt for AI Agents
In examples/gdn/example_chunk_o_bwd.py around lines 202-204 (allocation) and
usages at ~228 (cleared), ~279 (dg init) and ~292-293 (use), G_last_local is
allocated and cleared but never assigned before being used in the expression
G_last_local - G[bb, bs * block_S + i_s, bh]; set G_last_local from the G tensor
prior to its usage (mirror dg_last_local_0 initialization) so it holds G[bb, bs
* block_S + block_S - 1, bh] (or the appropriate last-index value for the
current block) before the subtraction to ensure correct computation.
```diff
 # TODO: alloc_var seems buggy here
-non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
-non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
+non_zero_cnt = T.alloc_var(dtype=T.uint8)
+non_zero_elt_log_idx = T.alloc_shared((elem,), dtype=T.uint8)
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Read the file around the lines in question
echo "=== Lines 250-285 of example_custom_compress.py ==="
sed -n '250,285p' examples/gemm_sp/example_custom_compress.py
# Check if the file exists
echo -e "\n=== File check ==="
ls -la examples/gemm_sp/example_custom_compress.py 2>&1 || echo "File not found"
```

Repository: tile-ai/tilelang
Length of output: 2160
🏁 Script executed:
```bash
#!/bin/bash
# Search for alloc_var definition and usage patterns
echo "=== Searching for T.alloc_var definitions ==="
rg -n "def alloc_var" --type py
echo -e "\n=== Searching for alloc_var usages in the file ==="
rg -n "alloc_var" examples/gemm_sp/ --type py -B 2 -A 2
echo -e "\n=== Searching for alloc_var in broader codebase ==="
rg -n "alloc_var" --type py | head -30
```

Repository: tile-ai/tilelang
Length of output: 4487
🏁 Script executed:
```bash
#!/bin/bash
# Read the alloc_var implementation
echo "=== alloc_var implementation in allocate.py ==="
sed -n '85,130p' tilelang/language/allocate.py
# Also check for any comments about scope or parallel safety
echo -e "\n=== Looking for scope/parallel-related comments ==="
grep -n -i "scope\|parallel\|thread" tilelang/language/allocate.py | head -20
```

Repository: tile-ai/tilelang
Length of output: 3795
🏁 Script executed:
```bash
#!/bin/bash
# Look at test cases to understand behavior
echo "=== Test case for alloc_var usage ==="
head -100 testing/python/language/test_tilelang_language_frontend_v2.py | grep -A 30 "def alloc_var\|T.alloc_var"
echo -e "\n=== Test with parallel context ==="
grep -B 5 -A 15 "T.Parallel.*alloc_var\|alloc_var.*T.Parallel" testing/python/language/ -r 2>/dev/null || echo "No parallel+alloc_var pattern found in tests"
```

Repository: tile-ai/tilelang
Length of output: 178
🏁 Script executed:
```bash
#!/bin/bash
# Read more context around the allocate.py to see full implementation
echo "=== Full alloc_var implementation ==="
sed -n '96,155p' tilelang/language/allocate.py
# Check if there are any docs about scope semantics
echo -e "\n=== Looking for alloc_buffer wrapper to understand scope handling ==="
grep -n "def alloc_buffer" tilelang/language/ -r
```

Repository: tile-ai/tilelang
Length of output: 3017
🏁 Script executed:
```bash
#!/bin/bash
# Read the full context of the parallel loop and variable usage
echo "=== Full context around parallel loop (lines 260-285) ==="
sed -n '260,285p' examples/gemm_sp/example_custom_compress.py
```

Repository: tile-ai/tilelang
Length of output: 1544
🏁 Script executed:
```bash
#!/bin/bash
# Confirm the exact usage pattern and race condition
echo "=== Detailed line-by-line analysis ==="
sed -n '264,280p' examples/gemm_sp/example_custom_compress.py | cat -n
echo -e "\n=== Summary of variable usage in parallel loop ==="
echo "non_zero_cnt (alloc_var): Line 268 (assign), 276 (increment)"
echo "non_zero_elt_log_idx (alloc_shared): Lines 269, 272, 277, 278 (multiple writes)"
echo ""
echo "Each thread in T.Parallel(block_M) has independent tm value"
echo "But both variables are allocated ONCE before the loop"
echo "non_zero_elt_log_idx with alloc_shared = SHARED MEMORY"
echo "Multiple threads write to same shared array = RACE CONDITION"
```

Repository: tile-ai/tilelang
Length of output: 1578
Resolve TODO and verify thread-safety of allocations within parallel loop.
The TODO on line 260 flags a legitimate concern. Both `non_zero_cnt` and `non_zero_elt_log_idx` are allocated before the parallel loop and used within it, creating concurrency hazards:

- **Line 262 race condition:** `non_zero_elt_log_idx` uses `T.alloc_shared`, creating shared memory accessible by all threads. Multiple threads write to this array concurrently (lines 269, 272, 277-278) without synchronization or thread-local indexing, causing data races.
- **Line 261 uncertainty:** `non_zero_cnt` uses `T.alloc_var` (default "local.var" scope), which should provide thread-local storage. However, the TODO suggests the semantics may not be working as intended in this parallel context. The reassignment at line 268 and increment at line 276 need guaranteed thread-local isolation.

Before merging, verify or fix:

- Replace `alloc_shared` for `non_zero_elt_log_idx` with thread-local allocation, or add thread-safety mechanisms (one possible direction is sketched below).
- Confirm `alloc_var` provides proper thread-local semantics within `T.Parallel` loops, or replace with an appropriate thread-local allocation.
…ample

* Updated index handling in `sparse_mla_fwd_pipelined.py` to eliminate unnecessary local array usage, improving code clarity and performance.
* Replaced instances of `indices_local[0]` with direct usage of `indices_local` for better readability and consistency in buffer access.
* Commented out the main execution call in the GDN test script to focus on the specific test function, enhancing test clarity.
/regression-perf
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)
**148-152**: Critical: `max_split` remains uninitialized - previous review issue not addressed.

This critical issue was flagged in a previous review but has not been fixed. At line 151, `max_split` is only assigned when `lse_local_split != 0`. If all log-sum-exp values are zero (when no valid blocks are processed in any split), `max_split` is never initialized but is used in conditionals at lines 155 and 160.

The previous array-based implementation would have zero-initialized `max_split[0]`, providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.

🔎 Required fix: Initialize max_split before the loop

```diff
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
+max_split = -1
 for k in T.serial(num_split):
     lse_local_split = glse[bz, by, k]
     if lse_local_split != 0:
         max_split = k
         lse_max_local = T.max(lse_max_local, glse[bz, by, k])
```

Initializing to `-1` ensures that if no valid splits are found, the conditions `k <= max_split` at lines 155 and 160 will be false for all k ≥ 0, correctly skipping the accumulation loops.
🧹 Nitpick comments (2)
examples/deepseek_mla/example_mla_decode.py (1)
**175-177**: Consider consistent initialization pattern for scalar variables.

Line 175 uses `T.clear(lse_logsum_local)` while line 177 uses direct assignment `lse_max_local = -T.infinity(accum_dtype)` to initialize scalar variables. Both approaches work, but for consistency, consider using the same pattern for all scalar variable initializations. Direct assignment (as on line 177) may be more idiomatic for scalars and aligns better with the scalar variable semantics.

🔎 Suggested refactor for consistency

```diff
-T.clear(lse_logsum_local)
+lse_logsum_local = T.cast(0, accum_dtype)
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
```

Alternatively, if `T.clear()` is the preferred TileLang pattern for initialization:

```diff
 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
-lse_max_local = -T.infinity(accum_dtype)
+T.clear(lse_max_local)
+lse_max_local = -T.infinity(accum_dtype)
```

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)
**145-147**: Consider using direct assignment for scalar initialization.

Line 145 uses `T.clear()` on the scalar variable `lse_logsum_local`, while line 147 uses direct assignment for `lse_max_local`. For consistency and idiomatic style, consider initializing `lse_logsum_local` with direct assignment: `lse_logsum_local = 0`

🔎 Proposed refactor for consistency

```diff
-T.clear(lse_logsum_local)
+lse_logsum_local = 0
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- examples/attention_sink/example_mha_sink_bwd_bhsd.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- examples/deepseek_mla/example_mla_decode.py
- examples/deepseek_v32/fp8_lighting_indexer.py
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py
- examples/fusedmoe/example_fusedmoe_tilelang.py
- examples/gdn/test_example_gdn_compilation.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/seer_attention/block_sparse_attn_tilelang.py
🚧 Files skipped from review as they are similar to previous changes (8)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/deepseek_v32/fp8_lighting_indexer.py
- examples/fusedmoe/example_fusedmoe_tilelang.py
- examples/seer_attention/block_sparse_attn_tilelang.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/attention_sink/example_mha_sink_bwd_bhsd.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/gdn/test_example_gdn_compilation.py
🧬 Code graph analysis (2)
examples/deepseek_mla/example_mla_decode.py (1)
tilelang/language/allocate.py (3)
`alloc_var` (89-89), `alloc_var` (93-93), `alloc_var` (96-150)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
tilelang/language/allocate.py (3)
`alloc_var` (89-89), `alloc_var` (93-93), `alloc_var` (96-150)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
examples/gdn/test_example_gdn_compilation.py (1)
319-320: Verify if this change is intentional for production.

This change replaces the test harness with a direct call to a single test function, meaning the other seven test functions in this file (`test_example_wy_fast_compilation`, `test_example_wy_fast_bwd_split_compilation`, `test_example_chunk_o_compilation`, `test_example_chunk_o_bwd_compilation`, `test_example_chunk_scaled_dot_kkt_compilation`, `test_example_cumsum_compilation`, `test_example_chunk_delta_h_compilation`) will no longer execute when running this script directly.

While the commit message mentions "focus testing," the use of a comment (rather than deletion) suggests this might be a temporary debugging change. If this is intended to be permanent, please clarify the rationale for running only this specific test function; a reversible alternative is sketched below.
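If the narrowing is meant to be temporary, a guarded runner keeps both modes available. A sketch only: the `--all` flag, the pytest fallback, and the focused-test name are assumptions, not code from this PR:

```python
# Hypothetical __main__ guard for the compilation test script.
import sys

if __name__ == "__main__":
    if "--all" in sys.argv:
        import pytest
        sys.exit(pytest.main([__file__]))   # run every test_* function in the file
    else:
        test_example_focused_compilation()  # placeholder name for the focused test
```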
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
115-115: LGTM: Clean refactoring to scalar variable.

The change from a 1-element local array to a scalar variable (`T.alloc_var`) is correct and improves code clarity by eliminating unnecessary array indirection.
266-305: LGTM: All usages correctly updated.

All reads and writes to `indices_local` have been correctly updated to work with the scalar variable, removing the unnecessary `[0]` indexing. The refactoring maintains the original logic while improving readability.
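The shape of the change, side by side (a sketch; the index expression and `stride` are illustrative, not the example's exact code):

```python
# Before: one-element local array, every access pays the [0] indirection.
indices_local = T.alloc_local([1], "int32")
indices_local[0] = Indices[b, s_i, g_i]
offset = indices_local[0] * stride

# After: a plain scalar variable with direct reads and writes.
indices_local = T.alloc_var("int32")
indices_local = Indices[b, s_i, g_i]
offset = indices_local * stride
```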
examples/minference/example_vertical_slash_sparse_attn.py (4)

133-135: LGTM! Clean refactoring to scalar variables.

The migration from single-element arrays to scalar variables using `T.alloc_var` improves code clarity by eliminating unnecessary array indirection for `block_count` and `column_count`.
146-147: LGTM! Scalar assignments are correct.

The direct scalar assignments properly replace the previous array-indexed pattern and work correctly with the `T.alloc_var` allocations.
163-177: LGTM! Loop constructs correctly use scalar values.

Both `T.serial(block_count)` loop constructs (lines 163, 177) correctly consume the scalar variable, eliminating the previous `[0]` indexing.
210-285: LGTM! Comprehensive scalar usage throughout the conditional block.

All references to `column_count` within the conditional block have been correctly updated to use the scalar variable:

- Conditional check (line 210)
- Function parameter passing to the `Prefetch` and `Compute` macros
- Arithmetic operations in `T.ceildiv(column_count, block_N)`

The changes are thorough and consistent with the macro signatures that expect `column_count: T.int32`; the call pattern is sketched below.
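A sketch of that call pattern (the macro body, the inner `Prefetch` call, and the tensor names are illustrative, not the example's code):

```python
# Hypothetical macro mirroring the Prefetch/Compute signatures: a scalar
# produced by T.alloc_var flows directly into a parameter typed T.int32
# and into loop extents.
@T.macro
def Compute(column_count: T.int32):
    # tile the valid column range; the extent comes straight from the scalar
    for bi in T.serial(T.ceildiv(column_count, block_N)):
        Prefetch(bi, column_count)      # hypothetical inner macro

column_count = T.alloc_var("int32")
column_count = ColumnCounts[by, bz]     # placeholder source tensor
if column_count > 0:                    # conditional check on the scalar
    Compute(column_count)
```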
170-173: LGTM! Scalar variable allocation refactoring improves clarity.

The conversion from `T.alloc_local([1], accum_dtype)` to `T.alloc_var(accum_dtype)` for these scalar variables removes unnecessary array notation and makes the intent explicit. This aligns well with the PR objective to phase out legacy single-element array allocations.
179-192: Excellent refactoring of scalar operations.

The removal of `[0]` indexing throughout these scalar operations significantly improves code readability while preserving the original logic:
- Max reduction (line 179)
- Log-sum-exp accumulation (lines 181-183)
- Scaling computation (lines 187-188)
- Final output accumulation (line 190)
The refactoring correctly distinguishes between scalar variables (now using direct operations) and fragment arrays like `po_local` and `o_accum_local` (which still use proper indexing), as the fragment below illustrates.
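The resulting mixed pattern reads roughly like this (an abbreviated sketch using the example's names):

```python
# Scalars from alloc_var are read and written directly; fragments keep
# their element indexing.
scale_local = T.alloc_var(accum_dtype)
po_local = T.alloc_fragment([dim], accum_dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)

scale_local = T.exp(lse_local_split - lse_max_local)  # scalar op, no [0]
for i in T.Parallel(dim):
    o_accum_local[i] += po_local[i] * scale_local     # fragment indexed, scalar direct
```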
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)

139-143: Scalar allocations look good.

The conversion from 1-element local arrays to scalar variables using `T.alloc_var` is consistent with the PR's refactoring objectives and simplifies the code.
154-168: Scalar operations refactored correctly.

The conversion to direct scalar operations (removing `[0]` indexing) is clean and maintains the original logic. The code is now more readable and consistent with the PR's objectives.

Note: The correctness of lines 155 and 160 depends on properly initializing `max_split`, as flagged in the separate comment above.
|
@codex review |
|
@regression-perf |
|
Codex Review: Didn't find any major issues. Breezy! ℹ️ About Codex in GitHub: Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |
|
@regression-perf |
This pull request refactors several attention and decoding kernels to improve code clarity and consistency in local variable allocation and usage. The main changes involve replacing single-element local array allocations with scalar variable allocations, simplifying assignments, and streamlining how masks and fragments are handled. Additionally, some logic has been updated to use more direct and readable expressions.
The most important changes are:
Refactoring local variable allocations:
Replaced `T.alloc_local([1], dtype)` and similar calls with `T.alloc_var(dtype)` for scalar variables in combine functions across various files, such as `example_tilelang_sparse_gqa_decode_paged.py`, `example_tilelang_sparse_gqa_decode_varlen_indice.py`, `example_tilelang_sparse_gqa_decode_varlen_mask.py`, `benchmark_mla_decode_amd_tilelang.py`, `example_mla_decode.py`, `example_mla_decode_paged.py`, and `example_mla_decode_persistent.py`. This change simplifies code and avoids unnecessary array indexing. [1] [2] [3] [4] [5] [6] [7]

Simplifying mask and fragment handling:

Changed `block_mask` from `T.alloc_local` to `T.alloc_fragment` and replaced manual copying with `T.copy` in both `benchmark_tilelang_block_sparse_fmha.py` and `example_tilelang_block_sparse_attn.py`, making the code more efficient and easier to read. [1] [2]

Improving assignment and comparison logic:

Updated `sink` in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to directly assign from `Sinks[bx]` instead of using a single-element array, and updated subsequent usage accordingly. [1] [2]

Enhancing code clarity in reduction and scaling computations:
These changes collectively make the codebase more readable, maintainable, and less error-prone, especially for those new to the code.
Summary by CodeRabbit
Refactor
Bug Fixes
Tests
✏️ Tip: You can customize this high-level summary in your review settings.