[Refactor] Phaseout legacy alloc_local statement in examples and introduce processing for floating fragment buffers #1495

Merged
LeiWang1999 merged 6 commits into tile-ai:main from LeiWang1999:refactor_1222 on Dec 22, 2025
Conversation

@LeiWang1999 (Member) commented Dec 22, 2025

This pull request refactors several attention and decoding kernels to improve code clarity and consistency in local variable allocation and usage. The main changes involve replacing single-element local array allocations with scalar variable allocations, simplifying assignments, and streamlining how masks and fragments are handled. Additionally, some logic has been updated to use more direct and readable expressions.

The most important changes are:

Refactoring local variable allocations:

  • Replaced T.alloc_local([1], dtype) and similar calls with T.alloc_var(dtype) for scalar variables in combine functions across various files, such as example_tilelang_sparse_gqa_decode_paged.py, example_tilelang_sparse_gqa_decode_varlen_indice.py, example_tilelang_sparse_gqa_decode_varlen_mask.py, benchmark_mla_decode_amd_tilelang.py, example_mla_decode.py, example_mla_decode_paged.py, and example_mla_decode_persistent.py. This change simplifies the code and avoids unnecessary array indexing.
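
A minimal before/after sketch of the pattern (variable names here are illustrative, not taken from any one diff):

    # Before: one-element local array, accessed via [0]
    scale_local = T.alloc_local([1], accum_dtype)
    scale_local[0] = T.exp2(lse_split - lse_logsum)
    o_accum[i] += po[i] * scale_local[0]

    # After: scalar variable, assigned and read directly
    scale_local = T.alloc_var(accum_dtype)
    scale_local = T.exp2(lse_split - lse_logsum)
    o_accum[i] += po[i] * scale_local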

Simplifying mask and fragment handling:

  • Changed the allocation of block_mask from T.alloc_local to T.alloc_fragment and replaced manual copying with T.copy in both benchmark_tilelang_block_sparse_fmha.py and example_tilelang_block_sparse_attn.py, making the code more efficient and easier to read.
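
A rough sketch of that change, with hypothetical shapes and index names:

    # Before: local buffer populated element by element
    block_mask = T.alloc_local([downsample_len], block_mask_dtype)
    for vj in T.serial(downsample_len):
        block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

    # After: fragment buffer filled with one bulk copy
    block_mask = T.alloc_fragment([downsample_len], block_mask_dtype)
    T.copy(BlockSparseMask[bz, by, bx, :], block_mask)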

Improving assignment and comparison logic:

  • Updated assignment of sink in example_gqa_sink_bwd_bhsd.py and example_mha_sink_bwd_bhsd.py to directly assign from Sinks[bx] instead of using a single-element array, and updated subsequent usage accordingly.
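
In sketch form (simplified from the description above; the dtype name is assumed):

    # Before: one-element buffer holding the sink value, read via sink[0]
    sink = T.alloc_local([1], accum_dtype)
    sink[0] = Sinks[bx]

    # After: direct scalar assignment, read as sink
    sink = Sinks[bx]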

Enhancing code clarity in reduction and scaling computations:

  • Refactored reduction and scaling logic in combine functions to use direct scalar assignments and arithmetic, removing unnecessary array indexing and layout annotations for local scalars.

These changes collectively make the codebase more readable, maintainable, and less error-prone, especially for those new to the code.

Summary by CodeRabbit

  • Refactor

    • Converted many per-thread local single-element buffers to scalar temporaries, simplifying kernels and reducing local allocation overhead.
    • Centralized buffer-type checks into helper predicates for more consistent memory-region classification.
  • Bug Fixes

    • Improved layout inference to detect and handle fragment buffers accessed outside expected regions, reducing layout-related failures.
  • Tests

    • Adjusted a test entrypoint to run a specific compilation check directly.


…ious examples and operations

* Updated multiple files to replace local buffer allocations with variable allocations for improved performance and clarity.
* Changed `alloc_local` to `alloc_var` in examples related to attention mechanisms, deep learning models, and GEMM operations.
* Enhanced code readability and maintainability by streamlining buffer management across different components.
* Ensured consistent handling of buffer scopes and types throughout the codebase.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Dec 22, 2025

Walkthrough

Refactors many kernels to replace one-element local tensors with scalar variables or fragment buffers, centralizes buffer-scope checks into predicate helpers, and adds detection/early-layout assignment for fragment buffers accessed outside TileOps.

Changes

  • Block-sparse mask -> fragment (benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py, examples/blocksparse_attention/example_tilelang_block_sparse_attn.py, examples/seer_attention/block_sparse_attn_tilelang.py): Convert block_mask from alloc_local to alloc_fragment, replace per-element population loops with bulk T.copy, and update gating checks to test non-zero entries.
  • Sink / scalar aliasing (examples/attention_sink/example_gqa_sink_bwd_bhsd.py, examples/attention_sink/example_mha_sink_bwd_bhsd.py): Replace one-element local sink allocations with direct scalar aliases (Sinks[bx]) and update downstream computations to use the scalar.
  • Per-thread/local fragment -> scalar alloc_var (LSE/scale reductions) (examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py, examples/deepseek_mla/..., examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py, examples/deepseek_mla/example_mla_decode*.py, examples/deepseek_mla/*_paged.py, examples/deepseek_mla/*_persistent.py, examples/deepseek_mla/*_ws.py): Replace one-element alloc_local([1], ...) fragments with scalar alloc_var variables; remove [0] indexing and rewrite initialization, max/exp/log accumulation, and final scaling to use scalars.
  • Single-element index/counter -> scalar (examples/deepseek_v32/fp8_lighting_indexer.py, examples/deepseek_v32/sparse_mla_fwd_pipelined.py, examples/fusedmoe/example_fusedmoe_tilelang.py, examples/grouped_gemm/example_grouped_gemm_bwd.py, examples/grouped_gemm/example_grouped_gemm_fwd.py, examples/minference/example_vertical_slash_sparse_attn.py): Replace single-element local index/counter buffers with scalar alloc_var variables; update all usages and control-flow arithmetic/indices accordingly.
  • GDN per-block local -> scalar / inline conditionals (examples/gdn/example_chunk_delta_bwd.py, examples/gdn/example_chunk_delta_h.py, examples/gdn/example_chunk_o_bwd.py): Replace small local tensors with scalars and refactor guarded If blocks into scalar-driven conditional expressions or ternary-like forms; update dependent arithmetic to use scalars.
  • Grouped GEMM / compress / misc example scalar refactors (examples/gemm_sp/example_custom_compress.py, examples/grouped_gemm/*, examples/fusedmoe/*, examples/minference/*, examples/blocksparse_attention/*): Convert various 1-element local buffers to scalars/alloc_var and update index/counter/accumulator uses to direct scalar operations.
  • Core: new buffer-type predicates (src/op/utils.h): Add inline helpers IsFragmentBuffer, IsSharedBuffer, IsGlobalBuffer, and IsLocalBuffer to centralize buffer-scope classification.
  • Core: replace string-scope checks with predicates (src/op/atomic_add.cc, src/op/copy.cc, src/op/fill.cc, src/op/gemm.cc, src/op/gemm_sp.cc, src/op/reduce.cc): Replace direct scope string comparisons (e.g., "local.fragment", "local", "shared") with IsFragmentBuffer/IsLocalBuffer/IsSharedBuffer predicates across layout-inference and lowering logic.
  • Core: parallel / loop transforms updated (src/op/parallel.cc, src/transform/common/loop_parallel_transform_utils.h, src/transform/loop_partition.cc): Use IsFragmentBuffer in BufferLoad/BufferStore visitors and let-binding expansion; add the utils.h include; restrict let-binding expansion to used vars and add logging for fragment expansions.
  • Core: layout inference, floating fragment handling (src/transform/layout_inference.cc): Add ComputeFloatingFragmentBuffers and floating_fragment_buffers_ to detect fragment buffers accessed outside TileOps; introduce an early step to assign fully replicated layouts for those buffers; switch scope checks to predicate helpers.
  • Core: legalize safe memory access cleanup (src/transform/legalize_safe_memory_access.cc): Remove now-redundant per-file helpers (IsLocalBuffer, isSharedBuffer, IsGlobalBuffer) in favor of the central predicates.
  • Tests / scripts (examples/gdn/test_example_gdn_compilation.py): Replace the tilelang.testing.main() entrypoint call with a direct test invocation (test_example_chunk_delta_bwd_compilation()) and remove the tilelang.testing import.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas needing extra attention:

  • src/transform/layout_inference.cc — new floating-fragment detection, thread_bounds capture, and early layout assignment.
  • src/op/gemm.cc — broad scope-check replacements across architectures; ensure predicate semantics fully match prior string checks.
  • Examples switching from per-element loops to T.copy and fragment allocations — validate fragment sizing and that T.copy preserves intended mask/layout/element semantics.
  • Examples where multi-branch Ifs were simplified to inline conditional expressions (numerical equivalence and edge-case handling).


Suggested reviewers

  • chengyupku

Poem

🐰 I nudged a fragment, made it small and neat,
Arrays turned scalars with a tidy beat.
Masks copied whole, no loops to comb,
Predicates whisper where scopes once roamed.
Hop — layouts find their cozy home.

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 21.00%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly summarizes the main refactoring work, phasing out legacy alloc_local statements and introducing floating fragment buffer processing.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/deepseek_mla/example_mla_decode.py (1)

170-192: Fix inconsistent scalar variable indexing.

On lines 170-173, scalar variables are allocated using alloc_var, but on line 190, scale_local is accessed with [0] indexing: o_accum_local[i] += po_local[i] * scale_local[0].

Since scale_local is now a scalar variable (not an array), the [0] indexing should be removed.

🔎 Proposed fix
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[0]
+                    o_accum_local[i] += po_local[i] * scale_local
src/transform/common/loop_parallel_transform_utils.h (1)

138-162: Consolidate duplicate IsFragmentBuffer definition for safety.

The code correctly uses IsFragmentBuffer(op->buffer) from src/op/utils.h which includes a buffer.defined() guard. However, src/transform/common/loop_fusion_utils.h (lines 70–74) defines a duplicate version without this guard. This inconsistency creates a potential null pointer dereference risk. Consolidate by removing the local definition in loop_fusion_utils.h and using the safe canonical version from op/utils.h.

🧹 Nitpick comments (3)
examples/gdn/example_chunk_o_bwd.py (2)

278-280: Consider renaming variable for clarity.

The variable dg_last_local_0 is reused on lines 278-280 with a completely different semantic meaning:

  • Lines 226, 253: Accumulates gradient contributions
  • Line 278: Assigned to a gate value from the G tensor
  • Line 280: Scaled by an exponential

While this reuse is functionally valid, it makes the code harder to understand. Consider using a separate variable name (e.g., G_last_scaled) for lines 278-280 to improve readability.


291-295: Complex conditional expression may benefit from clarification.

The conditional expression spans multiple lines and involves computing an exponential difference. Consider extracting the condition G_last_local - G[bb, bs * block_S + i_s, bh] <= 0 into a named variable to improve readability:

exp_diff = G_last_local - G[bb, bs * block_S + i_s, bh]
dk_fragment[i_s, i_k] = (
    dk_fragment[i_s, i_k] * T.exp(exp_diff) if exp_diff <= 0 else 0
)
src/op/parallel.cc (1)

195-198: Consider using DLOG(INFO) instead of LOG(INFO) for debug output.

This logging statement will execute in production builds and may produce excessive output. Based on the surrounding code patterns (e.g., lines 402-403 use DLOG(INFO)), debug-level logging with DLOG would be more appropriate here.

🔎 Suggested fix
       if (IsFragmentBuffer(bl->buffer) && !indice_map_.count(bl->buffer)) {
-          LOG(INFO) << "ExpandLetBindings: set buffer " << bl->buffer
-                    << " with indices " << bl->indices;
+          DLOG(INFO) << "ExpandLetBindings: set buffer " << bl->buffer
+                     << " with indices " << bl->indices;
           indice_map_.Set(bl->buffer, bl->indices);
         }
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a431797 and c471c62.

📒 Files selected for processing (35)
  • benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py
  • examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  • examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
  • examples/deepseek_mla/example_mla_decode.py
  • examples/deepseek_mla/example_mla_decode_paged.py
  • examples/deepseek_mla/example_mla_decode_persistent.py
  • examples/deepseek_mla/example_mla_decode_ws.py
  • examples/deepseek_v32/fp8_lighting_indexer.py
  • examples/deepseek_v32/sparse_mla_fwd_pipelined.py
  • examples/fusedmoe/example_fusedmoe_tilelang.py
  • examples/gdn/example_chunk_delta_bwd.py
  • examples/gdn/example_chunk_delta_h.py
  • examples/gdn/example_chunk_o_bwd.py
  • examples/gemm_sp/example_custom_compress.py
  • examples/grouped_gemm/example_grouped_gemm_bwd.py
  • examples/grouped_gemm/example_grouped_gemm_fwd.py
  • examples/minference/example_vertical_slash_sparse_attn.py
  • examples/seer_attention/block_sparse_attn_tilelang.py
  • src/op/atomic_add.cc
  • src/op/copy.cc
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm_sp.cc
  • src/op/parallel.cc
  • src/op/reduce.cc
  • src/op/utils.h
  • src/transform/common/loop_parallel_transform_utils.h
  • src/transform/layout_inference.cc
  • src/transform/legalize_safe_memory_access.cc
  • src/transform/loop_partition.cc
💤 Files with no reviewable changes (1)
  • src/transform/legalize_safe_memory_access.cc
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/seer_attention/block_sparse_attn_tilelang.py
  • examples/fusedmoe/example_fusedmoe_tilelang.py
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/op/parallel.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm.cc
🧬 Code graph analysis (20)
src/transform/common/loop_parallel_transform_utils.h (1)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/loop_partition.cc (2)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
src/op/reduce.cc (2)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
examples/seer_attention/block_sparse_attn_tilelang.py (4)
src/tl_templates/cuda/reduce.h (1)
  • T (178-250)
tilelang/language/allocate.py (1)
  • alloc_fragment (72-85)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/fill_op.py (1)
  • fill (9-36)
src/transform/layout_inference.cc (3)
src/op/utils.h (2)
  • IsFragmentBuffer (33-35)
  • IsLocalBuffer (50-53)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
src/layout/layout.cc (2)
  • FullyReplicated (552-556)
  • FullyReplicated (552-553)
src/op/parallel.cc (2)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
src/op/gemm.cc (2)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
src/op/atomic_add.cc (1)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/op/utils.h (2)
src/transform/layout_inference.cc (4)
  • buffer (593-616)
  • buffer (593-593)
  • buffer (924-943)
  • buffer (924-924)
src/transform/legalize_safe_memory_access.cc (6)
  • buffer (61-69)
  • buffer (61-61)
  • buffer (72-115)
  • buffer (72-73)
  • buffer (266-271)
  • buffer (266-266)
examples/deepseek_mla/example_mla_decode_ws.py (2)
tilelang/language/fill_op.py (1)
  • clear (39-62)
tilelang/language/loop.py (2)
  • Pipelined (58-95)
  • Parallel (13-33)
examples/deepseek_mla/example_mla_decode.py (1)
tilelang/language/allocate.py (3)
  • alloc_var (89-89)
  • alloc_var (93-93)
  • alloc_var (96-150)
examples/deepseek_mla/example_mla_decode_persistent.py (2)
tilelang/language/loop.py (2)
  • Pipelined (58-95)
  • Parallel (13-33)
src/transform/storage_rewrite.cc (1)
  • i (720-860)
examples/gemm_sp/example_custom_compress.py (1)
tilelang/language/allocate.py (1)
  • alloc_shared (40-55)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (4)
tilelang/language/v2/dtypes.py (1)
  • int32 (236-236)
tilelang/language/fill_op.py (1)
  • clear (39-62)
tilelang/language/loop.py (1)
  • Parallel (13-33)
src/transform/storage_rewrite.cc (1)
  • i (720-860)
examples/deepseek_v32/fp8_lighting_indexer.py (2)
tilelang/language/loop.py (1)
  • Pipelined (58-95)
tilelang/language/kernel.py (1)
  • threads (214-218)
examples/deepseek_mla/example_mla_decode_paged.py (2)
tilelang/language/fill_op.py (1)
  • clear (39-62)
tilelang/language/loop.py (2)
  • Pipelined (58-95)
  • Parallel (13-33)
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (3)
tilelang/language/allocate.py (1)
  • alloc_fragment (72-85)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/fill_op.py (1)
  • fill (9-36)
src/op/gemm_sp.cc (1)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
examples/gdn/example_chunk_o_bwd.py (2)
tilelang/language/allocate.py (2)
  • alloc_fragment (72-85)
  • alloc_shared (40-55)
tilelang/language/fill_op.py (1)
  • clear (39-62)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
tilelang/language/fill_op.py (1)
  • clear (39-62)
tilelang/language/loop.py (2)
  • Pipelined (58-95)
  • Parallel (13-33)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (72)
examples/gdn/example_chunk_delta_h.py (2)

159-159: LGTM!

Clean refactoring from single-element array to scalar allocation. Using T.alloc_var(T.float32) is the appropriate pattern for scalar values, eliminating unnecessary array indexing overhead.


204-216: Gating logic refactored correctly.

The scalar-based gating logic is cleaner and functionally equivalent:

  • Line 204: Direct scalar assignment from the last element of the chunk
  • Lines 209-213: Vectorized conditional correctly zeros out contributions where G_last_local > G_fragment
  • Lines 214-216: Proper exponential scaling of the hidden state

The constant 1.442695 (≈ 1/ln(2)) correctly implements exp(x) via exp2(x * 1.442695).
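
Spelled out, the identity in use is exp(x) = 2^(x * log2(e)) with log2(e) = 1/ln(2) ≈ 1.442695, which is why exp(x) can be computed as T.exp2(x * 1.442695).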

examples/fusedmoe/example_fusedmoe_tilelang.py (2)

162-166: LGTM! Scalar allocation refactor improves clarity.

The refactor correctly replaces single-element local arrays with direct scalar indexing from tensors. Using cur_group_idx = group_idx_for_bx[bx] and cur_group_size = group_sizes[cur_group_idx] eliminates unnecessary array indirection and makes the code more straightforward. All subsequent uses of these scalars (in m_start, actual_rows, and tensor indexing) are consistent and correct.

Also applies to: 179-179, 187-187


212-216: LGTM! Consistent scalar allocation pattern in Step 2 kernel.

The Step 2 kernel applies the same scalar allocation refactor as Step 1, maintaining consistency across the codebase. The pattern is correctly implemented and improves code readability by removing unnecessary array indirection.

Also applies to: 228-228

examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (2)

140-140: LGTM! Appropriate refactoring to fragment allocation.

The change from alloc_local to alloc_fragment correctly standardizes the block mask allocation to use fragment buffers, which aligns with the PR's refactoring objectives and is appropriate for per-thread mask storage used in downstream fragment operations.


147-147: LGTM! Efficient bulk copy replaces element-by-element loop.

The refactoring to use T.copy for bulk population of the block mask is more efficient and cleaner than the previous per-element loop. The slice notation correctly extracts the relevant dimension from BlockSparseMask.

benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (3)

141-141: LGTM! Consistent fragment allocation refactoring.

The change to alloc_fragment maintains consistency with the corresponding example file and correctly standardizes block mask allocation.


148-148: LGTM! Efficient bulk copy implementation.

The bulk copy using T.copy is consistent with the example file and provides better performance than element-by-element population.


155-155: LGTM! More explicit boolean comparison.

The explicit != 0 comparison improves code clarity and makes the intent more obvious, especially when working with fragment buffers. This is consistent with the pattern already used in the example file.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)

127-131: LGTM! Scalar allocations improve clarity.

The refactoring from single-element local arrays to scalar variables using T.alloc_var() is correct and improves code readability. All subsequent operations have been properly updated to work with scalars instead of array indexing.


133-135: Verify initialization pattern consistency for scalar variables.

Two different initialization patterns are used for scalar variables of the same type:

  • Line 133: T.clear(lse_logsum_local)
  • Line 135: lse_max_local = -T.infinity(accum_dtype) (direct assignment)

Please confirm whether T.clear() is the intended way to zero-initialize scalars allocated with T.alloc_var(), or if direct assignment is preferred for consistency.


136-156: LGTM! All scalar operations are correctly updated.

The refactoring properly removes array indexing throughout the combine function:

  • Direct scalar comparisons and assignments (lines 137-140, 143-145, 148-152)
  • Scalar arithmetic operations maintain correct semantics
  • Fragment operations correctly use scalars in computations (line 154)

The logic flow is preserved while achieving cleaner, more intuitive code.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

328-332: LGTM! Cleaner scalar access pattern.

The refactoring from a single-element local array to direct scalar assignment (sink = Sinks[bx]) eliminates unnecessary indirection and improves code readability. The scalar sink is appropriately used in the computation without array indexing.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

334-338: LGTM! Consistent scalar refactoring.

The change mirrors the GQA variant and follows the same pattern: replacing single-element local array allocation with direct scalar assignment. This improves code clarity and eliminates unnecessary indexing operations.

examples/gdn/example_chunk_delta_bwd.py (3)

306-307: LGTM! Scalar allocation simplifies the code.

The refactor from single-element local array to scalar variables (G_last_local and G_last_local_exp) is cleaner and removes unnecessary array indexing. The values are correctly derived from G_shared[block_S - 1] and used in subsequent computations.


308-313: LGTM! Conditional logic correctly implements numerical stability.

The refactor simplifies the conditional update by using a ternary expression. The pattern computes all exponentials first (line 309), then conditionally applies them (lines 311-313). While this computes exp(G_last_local - G_fragment[i_s2]) even when the difference is positive (and the result won't be used), IEEE floating point handles potential overflow gracefully by returning infinity, which is safely discarded when the condition sets the result to 0. This separation also enables better parallelization.


331-331: LGTM! Scalar usage is correct.

The direct multiplication by G_last_local_exp (without array indexing) is correct and consistent with the scalar refactor. This matches the torch reference logic where dh_tmp is scaled by the exponential of the last gate value.

examples/grouped_gemm/example_grouped_gemm_bwd.py (3)

30-31: LGTM: Scalar allocation simplifies the code.

The refactor from single-element arrays to scalar variables (T.alloc_var) is cleaner and aligns with the PR objectives. Type annotations are preserved correctly.


50-51: Native if statements within T.Parallel are standard across the codebase and represent the intentional direction of TileLang. This pattern appears consistently in multiple examples (grouped_gemm_fwd, grouped_gemm_bwd, fusedmoe) and was explicitly adopted by the TileLang maintainers in PR #162. No code generation concerns apply.
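
For reference, a hedged sketch of the pattern being discussed (bounds and buffer names are invented for illustration):

    for i in T.Parallel(block_M):
        # plain Python-level `if` inside T.Parallel, the pattern adopted in PR #162
        if m_start + i < cur_group_size:
            Out[m_start + i] = acc_frag[i]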


37-41: No action needed—cur_batch_idx is safely initialized.

T.alloc_var in TileLang defaults to zero initialization when no explicit init parameter is provided. Line 30 allocates cur_batch_idx without an explicit initializer, so it defaults to zero. The subsequent loop logic on line 37 safely reads this initialized value. The concern about uninitialized variable usage is not valid; the code is correct.

examples/grouped_gemm/example_grouped_gemm_fwd.py (4)

64-65: LGTM! Clean refactor to scalar variables.

The migration from single-element local arrays to T.alloc_var scalar allocations simplifies the code by eliminating unnecessary array indexing throughout the batch index computation logic.


73-75: LGTM! Consistent scalar usage.

The scalar variables are used correctly throughout the batch index computations and tensor indexing. The refactor maintains logical equivalence while improving readability by removing array indirection.

Also applies to: 80-80


71-71: Verify initialization of cur_batch_idx before use.

Line 71 references cur_batch_idx on the right-hand side of the assignment within the conditional expression. If the first iteration of the loop has in_cur_batch_idx = false, this would read from an uninitialized variable. While the old array-based allocation would zero-initialize, it's unclear whether T.alloc_var provides default initialization.

The logic likely prevents this in practice (since batch_padded_offsets[0] should be 0 and m_start_padded >= 0), but the code would be more robust with explicit initialization.

Run the following script to verify whether T.alloc_var guarantees initialization and check for any related documentation:

#!/bin/bash
# Description: Search for alloc_var initialization behavior and usage patterns

# Check for alloc_var initialization semantics
rg -n -C5 'def alloc_var'

# Look for other examples where alloc_var is explicitly initialized
rg -n -A3 'alloc_var.*=.*T\.alloc_var' --type=py

# Check for patterns where alloc_var is used before assignment
ast-grep --pattern $'$VAR = T.alloc_var($$$)
$$$
$_ = T.if_then_else($_, $_, $VAR)'


examples/deepseek_mla/example_mla_decode_persistent.py (1)

53-56: LGTM! Scalar allocation refactoring is correct.

The refactoring from single-element local arrays to scalar variables using alloc_var is properly implemented. All usages throughout the combine function correctly access these scalars directly without array indexing.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)

124-127: LGTM! Consistent scalar variable refactoring.

The conversion from fragment-based to scalar-based storage is correctly implemented. All operations (lines 131-144) properly use direct scalar access without indexing.

examples/deepseek_v32/fp8_lighting_indexer.py (1)

133-142: LGTM! Scalar variable refactoring is correct.

The refactoring from array-based to scalar-based min/max tracking is properly implemented. All subsequent usages (lines 147-168) correctly access these scalars directly.

examples/deepseek_mla/example_mla_decode_paged.py (1)

191-211: LGTM! Scalar refactoring is correctly implemented.

The scalar variable refactoring is properly completed with consistent direct access throughout the combine function. Line 211 correctly uses scale_local without indexing.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)

115-115: No changes needed. The code is correct as written.

T.alloc_var() creates a 1-element buffer (via T.alloc_buffer([1], dtype, ...)), not a scalar variable. The [0] indexing throughout the code (lines 266, 273, 276, 281, 288, 289, 295, 298, 303) is required and correct—it accesses the single element of the allocated buffer, which is the intended usage pattern.

Likely an incorrect or invalid review comment.

src/transform/loop_partition.cc (2)

31-31: LGTM!

Adding the include for the centralized IsFragmentBuffer predicate improves consistency with other parts of the codebase.


222-234: LGTM!

The refactoring from direct string comparison (op->buffer.scope() == "local.fragment") to the IsFragmentBuffer(op->buffer) predicate is a good improvement. This centralizes the fragment buffer detection logic, making it easier to maintain and ensuring consistency across the codebase.

src/op/gemm_sp.cc (3)

224-224: LGTM!

Using IsFragmentBuffer(c_) for the precondition check is consistent with the PR's goal of centralizing fragment buffer detection.


276-276: LGTM!

The predicate-based check for a_ fragment buffer is consistent with the refactoring pattern.


290-290: LGTM!

The predicate-based check for b_ fragment buffer completes the consistent use of IsFragmentBuffer throughout the InferLayout method.

src/transform/common/loop_parallel_transform_utils.h (1)

20-21: LGTM!

The include is correctly added to use the centralized IsFragmentBuffer predicate.

src/op/reduce.cc (2)

383-384: LGTM!

The refactoring to use IsFragmentBuffer(src) && IsFragmentBuffer(dst) for the layout inference condition is consistent with the PR's goal of centralizing fragment buffer detection.


521-521: LGTM!

The change to use IsFragmentBuffer(this->src) && IsFragmentBuffer(this->dst) aligns with the predicate-based approach being adopted across the codebase.

src/op/fill.cc (1)

159-181: LGTM! Clean refactor to predicate-based buffer scope checks.

The transition from string-based scope comparisons to centralized predicates (IsFragmentBuffer, IsLocalBuffer, IsSharedBuffer, IsGlobalBuffer) improves maintainability and consistency across the codebase.

src/op/parallel.cc (4)

209-223: Good optimization: filtering let-binding expansion to only used variables.

This change correctly limits the expansion to variables that are actually referenced in root_, avoiding unnecessary work. The implementation properly collects used vars via PostOrderVisit and gates expansion accordingly.


151-165: LGTM! Consistent use of IsFragmentBuffer predicate in visitor methods.

The refactoring from op->buffer.scope() == "local.fragment" to IsFragmentBuffer(op->buffer) is applied consistently in both BufferStoreNode and BufferLoadNode visitors.


275-336: LGTM! Consistent predicate usage across layout inference logic.

The IsFragmentBuffer predicate is now used consistently throughout the layout inference code paths, replacing all direct string comparisons.


15-15: Duplicate definition claim is inaccurate.

The IsFragmentBuffer function exists in both headers, but the implementations are NOT identical:

  • src/op/utils.h (inline): return buffer.defined() && buffer.scope() == "local.fragment";
  • src/transform/common/loop_fusion_utils.h (non-inline): return scope == "local.fragment";

The non-inline version in loop_fusion_utils.h omits the buffer.defined() check. Additionally, one is marked inline while the other is not, creating a more serious concern: the inline specifier cannot re-declare a function that was already defined in the translation unit as non-inline.

This is a pre-existing issue in files that include both headers (e.g., layout_inference.cc, fill.cc, atomic_add.cc, copy.cc), but the current change to parallel.cc only includes the inline version from utils.h and does not introduce new problems.

Likely an incorrect or invalid review comment.

src/op/gemm.cc (5)

523-533: LGTM! Correct use of IsFragmentBuffer for A, B, C buffer classification in Lower().

The refactoring correctly identifies fragment buffers for determining the GEMM operation variant (gemm_rs, gemm_sr, gemm_ss) and validates that C is a fragment buffer.


605-617: LGTM! Consistent predicate usage in Volta target path.

The checks for C and A buffer scopes are correctly refactored to use IsFragmentBuffer.


633-667: LGTM! Ampere/Turing/SM120/SM100 path correctly updated.

Fragment buffer checks for A and B in the MMA code path use the centralized predicate consistently.


669-714: LGTM! Hopper target path updated consistently.

The IsFragmentBuffer(c_) check and the implicit else branches for fragment A/B are handled correctly.


775-809: LGTM! CDNA target path uses predicates correctly.

The fragment buffer checks for C, A, and B in the MFMA code path are consistently updated.

src/op/atomic_add.cc (3)

328-339: LGTM! Layout comparison guard updated to use predicates.

The condition correctly uses IsFragmentBuffer for both source and destination buffers when comparing fragment layouts.


434-445: LGTM! Visitor methods in AtomicLoopNestCollector use predicates consistently.

Both BufferStoreNode and BufferLoadNode visitors correctly use IsFragmentBuffer for buffer classification.


476-477: LGTM! Fragment buffer check in layout inference loop.

The negated check !IsFragmentBuffer(buf) correctly replaces the previous string comparison.

examples/minference/example_vertical_slash_sparse_attn.py (4)

133-136: LGTM! Correct migration from single-element arrays to scalar variables.

Replacing T.alloc_local([1], dtype) with T.alloc_var(dtype=int_dtype) for block_count and column_count eliminates unnecessary array overhead for scalar values. This aligns with the PR objective of phasing out legacy alloc_local for single-element allocations.


146-147: LGTM! Direct scalar assignment.

The assignments now use direct scalar semantics instead of array indexing, which is cleaner and more efficient.


163-177: LGTM! Loop bounds and control flow correctly use scalar values.

The T.serial(block_count) loop bound and other usages correctly operate on the scalar variable directly.


210-285: LGTM! Conditional logic and macro calls updated for scalar semantics.

All conditional checks (column_count != 0), loop bounds (T.ceildiv(column_count, block_N)), and macro parameter passing correctly use the scalar column_count value throughout the prefetch and compute sections.

src/op/copy.cc (6)

776-780: LGTM! Clean refactoring to use predicate.

The replacement of direct scope string comparison with IsFragmentBuffer(dst) improves code maintainability and consistency across the codebase.


793-796: LGTM! Consistent predicate usage.

The change aligns with the refactoring pattern established throughout the file.


808-826: LGTM! Tensor memory checks properly refactored.

Both CheckTMemLoad and CheckTMemStore now use the centralized IsFragmentBuffer predicate, maintaining consistency with other memory operation checks in the file.


953-958: Refined warning logic for local buffer copies.

The changes improve precision by:

  • Using the IsLocalBuffer predicate for cleaner buffer type checking
  • Narrowing the warning to only fire for local→non-local copies (which can cause conflicted writes)
  • Eliminating false-positive warnings for non-local→local copies

This is an intentional behavior refinement, not just a refactoring.


1234-1243: LGTM! Tensor memory copy type detection refactored.

The conditions for detecting tcgen05.ld and tcgen05.st operations now use the centralized predicate, maintaining consistency with the rest of the codebase.


2063-2086: LGTM! Fragment layout collection uses predicate.

The change in CollectFragmentLayouts aligns with the overall refactoring to use centralized buffer-type predicates.

src/op/utils.h (1)

32-53: Well-designed buffer-type predicates.

The predicates provide a clean abstraction for buffer classification:

  • All predicates defensively check buffer.defined() before accessing scope
  • IsLocalBuffer correctly handles both "local" and "local.var" scopes
  • IsSharedBuffer provides flexibility with the allow_dynamic parameter (defaulting to true)
  • Inline implementation ensures zero overhead

These centralized predicates improve maintainability and reduce the risk of inconsistent scope checks across the codebase.

src/transform/layout_inference.cc (7)

23-23: LGTM! Required include for buffer-type predicates.

The include of utils.h provides access to the centralized IsFragmentBuffer and IsLocalBuffer predicates used throughout this file.


174-175: LGTM! Consistent predicate usage for fragment buffer detection.

All replacements of direct scope string comparisons with IsFragmentBuffer(buffer) are functionally equivalent and improve code maintainability.

Also applies to: 383-383, 595-595, 790-790


312-321: Well-designed floating fragment buffer initialization.

Step 0 correctly handles fragment buffers accessed outside TileOps by:

  • Assigning fully replicated layouts (since access patterns cannot be inferred from TileOp semantics)
  • Using the captured thread_bounds from the access point
  • Respecting pre-existing layouts in layout_map

This preprocessing step runs before strict layout inference (step 1), ensuring these buffers have valid layouts throughout the inference pipeline.


436-438: Correct placement of floating buffer detection.

Computing floating fragment buffers after the main collection traversal ensures that the set of TileOp nodes is complete before scanning for accesses outside those operations.


862-964: Robust floating fragment buffer detection algorithm.

The implementation correctly identifies fragment buffers accessed outside TileOps:

  1. Node collection (lines 884-890): Uses PostOrderVisit to build a complete set of nodes inside TileOps
  2. Visitor pattern (lines 894-949): Tracks thread context while scanning for fragment buffer accesses
  3. Access classification (lines 924-943): Efficiently checks if accesses are floating with early returns for non-fragment buffers, TileOp-internal accesses, and duplicates
  4. Thread bounds capture: Extracts bounds from the analyzer with appropriate fallback to Range(0,1)

The comprehensive documentation (lines 862-877) with a concrete example makes the logic easy to understand and maintain.
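
To make "floating" concrete, here is a hypothetical TileLang fragment access outside any TileOp (names invented for illustration):

    acc = T.alloc_fragment((block_M, block_N), accum_dtype)
    T.gemm(A_shared, B_shared, acc)  # access inside a TileOp: layout can be inferred
    s = T.alloc_var(accum_dtype)
    s = acc[0, 0]                    # raw element read outside any TileOp: a "floating" access
    # Per this PR, step 0 assigns such buffers a fully replicated layout before strict inference runs.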


971-977: Well-documented data member for floating buffers.

The data structure appropriately:

  • Uses std::unordered_map for efficient lookup
  • Stores Buffer → Range mappings to track thread bounds per buffer
  • Employs correct hash/equality predicates (ObjectPtrHash, ObjectPtrEqual) for Buffer keys
  • Includes clear documentation linking to ComputeFloatingFragmentBuffers()

1245-1359: Correct buffer classification in loop transformation logic.

The VisitStmt_(ForNode) method properly uses IsLocalBuffer and IsFragmentBuffer predicates to:

  • Lines 1268, 1288, 1292: Detect register-local-only loops that don't require thread partitioning
  • Lines 1313, 1318: Identify non-local memory accesses requiring vectorization

The refactored logic maintains the correct semantics for loop partitioning and vectorization decisions.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (3)

138-142: Good refactor from fragment to scalar allocations.

The change from single-element fragment allocations to scalar variables using T.alloc_var() removes unnecessary indirection and improves code clarity. This aligns well with the PR objectives.

However, see the critical initialization issue flagged below for max_split.


153-167: Scalar refactor correctly applied in reduction and accumulation loops.

The scalar operations in the log-sum-exp reduction (lines 153-157) and output accumulation (lines 158-165) are correctly refactored from array-indexed operations to direct scalar assignments and arithmetic. The logic is preserved and the code is more readable without the unnecessary [0] indexing.

The only issue is the reliance on max_split which may be uninitialized (flagged above).


144-146: No changes needed. The code correctly uses T.clear() on variables created by T.alloc_var(), which returns a 1-element buffer, not a plain scalar. The T.clear() API is designed for buffers and properly initializes them to zero. This usage pattern is consistent throughout the codebase and is correctly documented in the type signatures.

Likely an incorrect or invalid review comment.

examples/seer_attention/block_sparse_attn_tilelang.py (1)

111-111: Verify fragment buffer performance characteristics with conditional mask access.

The change from T.alloc_local to T.alloc_fragment for block_mask (line 111) is syntactically correct and the bulk T.copy initialization (line 118) is well-formed. However, fragment memory is designed for accumulators in register files to minimize latency in tensor core operations, while this mask is only read in a conditional check at line 123.

The refactoring aligns with the PR's systematic standardization of fragment buffer usage, but this unconventional use case (scalar conditional reads from a fragment buffer) should be validated to ensure the TileLang compiler generates efficient code and doesn't introduce unexpected performance overhead or compiler issues.

Comment on lines 147 to +151
 for k in T.serial(num_split):
-    lse_local_split[0] = glse[bz, by, k]
-    if lse_local_split[0] != 0:
-        max_split[0] = k
-        lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
+    lse_local_split = glse[bz, by, k]
+    if lse_local_split != 0:
+        max_split = k
+        lse_max_local = T.max(lse_max_local, glse[bz, by, k])

⚠️ Potential issue | 🔴 Critical

Critical: max_split used uninitialized when all lse values are zero.

At line 150, max_split is only assigned inside the conditional if lse_local_split != 0. If all log-sum-exp values are zero (which can occur when no valid blocks are processed in any split), max_split is never initialized but is subsequently used in conditionals at lines 154 and 159.

In the previous array-based implementation, max_split[0] would have been zero-initialized, providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.

🔎 Proposed fix: Initialize max_split before the loop
 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
+max_split = -1
 for k in T.serial(num_split):
     lse_local_split = glse[bz, by, k]
     if lse_local_split != 0:
         max_split = k
         lse_max_local = T.max(lse_max_local, glse[bz, by, k])

Initializing to -1 ensures that if no valid splits are found, the conditions k <= max_split at lines 154 and 159 will be false for all k ≥ 0, correctly skipping the accumulation loops.

🤖 Prompt for AI Agents
In examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
around lines 147 to 151, max_split is only assigned inside the if and can remain
uninitialized when all lse values are zero; initialize max_split to -1
immediately before the for k in T.serial(num_split) loop so that when no valid
splits are found the later checks (k <= max_split) are false and accumulation
loops are skipped, ensuring defined behavior in the all-zero case.

Comment on lines +176 to 196
+lse_local_split = T.alloc_var(accum_dtype)
+lse_logsum_local = T.alloc_var(accum_dtype)
+lse_max_local = T.alloc_var(accum_dtype)
+scale_local = T.alloc_var(accum_dtype)

 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
-lse_max_local[0] = -T.infinity(accum_dtype)
+lse_max_local = -T.infinity(accum_dtype)
 for k in T.serial(num_split):
-    lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
+    lse_max_local = T.max(lse_max_local, glse[bz, by, k])
 for k in T.Pipelined(num_split, num_stages=1):
-    lse_local_split[0] = glse[bz, by, k]
-    lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
-lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
+    lse_local_split = glse[bz, by, k]
+    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
 for k in T.serial(num_split):
     for i in T.Parallel(dim):
         po_local[i] = Output_partial[bz, by, k, i]
-    lse_local_split[0] = glse[bz, by, k]
-    scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
+    lse_local_split = glse[bz, by, k]
+    scale_local = T.exp2(lse_local_split - lse_logsum_local)
     for i in T.Parallel(dim):
         o_accum_local[i] += po_local[i] * scale_local[0]

⚠️ Potential issue | 🔴 Critical

Bug: scale_local[0] indexing on a scalar variable.

scale_local is allocated as a scalar via T.alloc_var(accum_dtype) at line 179, but is incorrectly indexed as scale_local[0] at line 196. This will cause an error since scalars cannot be indexed.

🔎 Proposed fix
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[0]
+                    o_accum_local[i] += po_local[i] * scale_local
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    lse_local_split = T.alloc_var(accum_dtype)
    lse_logsum_local = T.alloc_var(accum_dtype)
    lse_max_local = T.alloc_var(accum_dtype)
    scale_local = T.alloc_var(accum_dtype)

    T.clear(lse_logsum_local)
    T.clear(o_accum_local)
    lse_max_local = -T.infinity(accum_dtype)
    for k in T.serial(num_split):
        lse_max_local = T.max(lse_max_local, glse[bz, by, k])
    for k in T.Pipelined(num_split, num_stages=1):
        lse_local_split = glse[bz, by, k]
        lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
    lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
    for k in T.serial(num_split):
        for i in T.Parallel(dim):
            po_local[i] = Output_partial[bz, by, k, i]
        lse_local_split = glse[bz, by, k]
        scale_local = T.exp2(lse_local_split - lse_logsum_local)
        for i in T.Parallel(dim):
            o_accum_local[i] += po_local[i] * scale_local
🤖 Prompt for AI Agents
In examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py around lines
176 to 196, scale_local is allocated as a scalar (T.alloc_var(accum_dtype)) but
later used with an index (scale_local[0]) which is invalid; change the usage to
treat scale_local as a scalar (remove the [0] indexing) wherever it is used
(e.g., replace scale_local[0] with scale_local) so the multiplication into
o_accum_local uses the scalar value, or alternatively allocate scale_local as a
1-element buffer if indexing is required—prefer the scalar fix for minimal
change.

Comment on lines +475 to 495
+lse_local_split = T.alloc_var(accum_dtype)
+lse_logsum_local = T.alloc_var(accum_dtype)
+lse_max_local = T.alloc_var(accum_dtype)
+scale_local = T.alloc_var(accum_dtype)

 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
-lse_max_local[0] = -T.infinity(accum_dtype)
+lse_max_local = -T.infinity(accum_dtype)
 for k in T.serial(num_split):
-    lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
+    lse_max_local = T.max(lse_max_local, glse[bz, hid, k])
 for k in T.Pipelined(num_split, num_stages=1):
-    lse_local_split[0] = glse[bz, hid, k]
-    lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
-lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
+    lse_local_split = glse[bz, hid, k]
+    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
 for k in T.serial(num_split):
     for i in T.Parallel(dim):
         po_local[i] = Output_partial[bz, hid, k, i]
-    lse_local_split[0] = glse[bz, hid, k]
-    scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
+    lse_local_split = glse[bz, hid, k]
+    scale_local = T.exp2(lse_local_split - lse_logsum_local)
     for i in T.Parallel(dim):
         o_accum_local[i] += po_local[i] * scale_local[0]

⚠️ Potential issue | 🔴 Critical

Fix inconsistent scalar variable indexing.

Similar to example_mla_decode.py, scalar variables are allocated with alloc_var on lines 475-478, but line 495 still uses array indexing: o_accum_local[i] += po_local[i] * scale_local[0].

Since scale_local is now a scalar variable, remove the [0] indexing.

🔎 Proposed fix
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[0]
+                    o_accum_local[i] += po_local[i] * scale_local
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    lse_local_split = T.alloc_var(accum_dtype)
    lse_logsum_local = T.alloc_var(accum_dtype)
    lse_max_local = T.alloc_var(accum_dtype)
    scale_local = T.alloc_var(accum_dtype)

    T.clear(lse_logsum_local)
    T.clear(o_accum_local)
    lse_max_local = -T.infinity(accum_dtype)
    for k in T.serial(num_split):
        lse_max_local = T.max(lse_max_local, glse[bz, hid, k])
    for k in T.Pipelined(num_split, num_stages=1):
        lse_local_split = glse[bz, hid, k]
        lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
    lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
    for k in T.serial(num_split):
        for i in T.Parallel(dim):
            po_local[i] = Output_partial[bz, hid, k, i]
        lse_local_split = glse[bz, hid, k]
        scale_local = T.exp2(lse_local_split - lse_logsum_local)
        for i in T.Parallel(dim):
            o_accum_local[i] += po_local[i] * scale_local
🤖 Prompt for AI Agents
In examples/deepseek_mla/example_mla_decode_ws.py around lines 475 to 495,
scalar locals lse_local_split, lse_logsum_local, lse_max_local, and scale_local
are allocated with alloc_var but the code treats scale_local as an array at line
495 using scale_local[0]; update the usage to treat scale_local as a scalar by
removing the [0] index (i.e., use scale_local directly in the multiplication)
and verify no other array-style indexing remains for these alloc_var scalars.

Comment on lines +202 to +204
dg_last_local_0 = T.alloc_var(dtype=gate_dtype)
dg_last_local_1 = T.alloc_var(dtype=gate_dtype)
G_last_local = T.alloc_var(dtype=gate_dtype)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find and examine the specific file
fd -name "example_chunk_o_bwd.py" -type f

Repository: tile-ai/tilelang

Length of output: 290


🏁 Script executed:

#!/bin/bash
# Search for the file and examine relevant lines
find . -path "*examples/gdn*" -name "*bwd*" -type f 2>/dev/null | head -5

Repository: tile-ai/tilelang

Length of output: 183


🏁 Script executed:

#!/bin/bash
# Examine the specific file, focusing on lines around 202-204, 228, and 293
wc -l ./examples/gdn/example_chunk_o_bwd.py

Repository: tile-ai/tilelang

Length of output: 101


🏁 Script executed:

#!/bin/bash
# Read the relevant sections of the file
echo "=== Lines 200-210 (around allocation at 202-204) ==="
sed -n '200,210p' ./examples/gdn/example_chunk_o_bwd.py

echo ""
echo "=== Lines 225-235 (around line 228) ==="
sed -n '225,235p' ./examples/gdn/example_chunk_o_bwd.py

echo ""
echo "=== Lines 290-300 (around line 293) ==="
sed -n '290,300p' ./examples/gdn/example_chunk_o_bwd.py

echo ""
echo "=== Full context: lines 200-320 to understand the flow ==="
sed -n '200,320p' ./examples/gdn/example_chunk_o_bwd.py | cat -n

Repository: tile-ai/tilelang

Length of output: 10081


🏁 Script executed:

#!/bin/bash
# Let me see more context around where G_last_local should be assigned
echo "=== Checking lines 74-82 for G_last_local assignment ==="
sed -n '74,82p' ./examples/gdn/example_chunk_o_bwd.py | cat -n

Repository: tile-ai/tilelang

Length of output: 540


🏁 Script executed:

#!/bin/bash
# Check the full context where G_last_local is being used
echo "=== Lines 74-105 context ==="
sed -n '274,310p' ./examples/gdn/example_chunk_o_bwd.py | cat -n

Repository: tile-ai/tilelang

Length of output: 2720


🏁 Script executed:

#!/bin/bash
# Search for all assignments to G_last_local in the file
echo "=== All lines where G_last_local is used/assigned ==="
grep -n "G_last_local" ./examples/gdn/example_chunk_o_bwd.py

Repository: tile-ai/tilelang

Length of output: 397


Initialize G_last_local before usage on lines 292-293.

The scalar variable G_last_local is allocated on line 204 and cleared on line 228, but it is never assigned a meaningful value before being used on lines 292-293 in the expression G_last_local - G[bb, bs * block_S + i_s, bh]. Since it remains at the cleared (zero) value, this produces incorrect computation. Assign G_last_local from the G tensor before its usage, similar to how dg_last_local_0 is initialized from G on line 279 (e.g., G_last_local = G[bb, bs * block_S + block_S - 1, bh]).
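
A minimal sketch of that fix, assuming the names visible in the scripts above (G, bb, bs, bh, block_S, gate_dtype); exact placement depends on the surrounding loop structure:

G_last_local = T.alloc_var(dtype=gate_dtype)
# Load the gate value of the block's last row before it is consumed,
# mirroring how dg_last_local_0 is initialized from G.
G_last_local = G[bb, bs * block_S + block_S - 1, bh]

With this in place, the subtraction G_last_local - G[bb, bs * block_S + i_s, bh] on lines 292-293 sees the real last-row value instead of the cleared zero.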

🤖 Prompt for AI Agents
In examples/gdn/example_chunk_o_bwd.py around lines 202-204 (allocation) and
usages at ~228 (cleared), ~279 (dg init) and ~292-293 (use), G_last_local is
allocated and cleared but never assigned before being used in the expression
G_last_local - G[bb, bs * block_S + i_s, bh]; set G_last_local from the G tensor
prior to its usage (mirror dg_last_local_0 initialization) so it holds G[bb, bs
* block_S + block_S - 1, bh] (or the appropriate last-index value for the
current block) before the subtraction to ensure correct computation.

Comment on lines 260 to +262
 # TODO: alloc_var seems buggy here
-non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
-non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
+non_zero_cnt = T.alloc_var(dtype=T.uint8)
+non_zero_elt_log_idx = T.alloc_shared((elem,), dtype=T.uint8)
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash

# Read the file around the lines in question
echo "=== Lines 250-285 of example_custom_compress.py ==="
sed -n '250,285p' examples/gemm_sp/example_custom_compress.py

# Check if the file exists
echo -e "\n=== File check ==="
ls -la examples/gemm_sp/example_custom_compress.py 2>&1 || echo "File not found"

Repository: tile-ai/tilelang

Length of output: 2160


🏁 Script executed:

#!/bin/bash
# Search for alloc_var definition and usage patterns
echo "=== Searching for T.alloc_var definitions ==="
rg -n "def alloc_var" --type py

echo -e "\n=== Searching for alloc_var usages in the file ==="
rg -n "alloc_var" examples/gemm_sp/ --type py -B 2 -A 2

echo -e "\n=== Searching for alloc_var in broader codebase ==="
rg -n "alloc_var" --type py | head -30

Repository: tile-ai/tilelang

Length of output: 4487


🏁 Script executed:

#!/bin/bash
# Read the alloc_var implementation
echo "=== alloc_var implementation in allocate.py ==="
sed -n '85,130p' tilelang/language/allocate.py

# Also check for any comments about scope or parallel safety
echo -e "\n=== Looking for scope/parallel-related comments ==="
grep -n -i "scope\|parallel\|thread" tilelang/language/allocate.py | head -20

Repository: tile-ai/tilelang

Length of output: 3795


🏁 Script executed:

#!/bin/bash
# Look at test cases to understand behavior
echo "=== Test case for alloc_var usage ==="
head -100 testing/python/language/test_tilelang_language_frontend_v2.py | grep -A 30 "def alloc_var\|T.alloc_var"

echo -e "\n=== Test with parallel context ==="
grep -B 5 -A 15 "T.Parallel.*alloc_var\|alloc_var.*T.Parallel" testing/python/language/ -r 2>/dev/null || echo "No parallel+alloc_var pattern found in tests"

Repository: tile-ai/tilelang

Length of output: 178


🏁 Script executed:

#!/bin/bash
# Read more context around the allocate.py to see full implementation
echo "=== Full alloc_var implementation ==="
sed -n '96,155p' tilelang/language/allocate.py

# Check if there are any docs about scope semantics
echo -e "\n=== Looking for alloc_buffer wrapper to understand scope handling ==="
grep -n "def alloc_buffer" tilelang/language/ -r

Repository: tile-ai/tilelang

Length of output: 3017


🏁 Script executed:

#!/bin/bash
# Read the full context of the parallel loop and variable usage
echo "=== Full context around parallel loop (lines 260-285) ==="
sed -n '260,285p' examples/gemm_sp/example_custom_compress.py

Repository: tile-ai/tilelang

Length of output: 1544


🏁 Script executed:

#!/bin/bash
# Confirm the exact usage pattern and race condition
echo "=== Detailed line-by-line analysis ==="
sed -n '264,280p' examples/gemm_sp/example_custom_compress.py | cat -n

echo -e "\n=== Summary of variable usage in parallel loop ==="
echo "non_zero_cnt (alloc_var): Line 268 (assign), 276 (increment)"
echo "non_zero_elt_log_idx (alloc_shared): Lines 269, 272, 277, 278 (multiple writes)"
echo ""
echo "Each thread in T.Parallel(block_M) has independent tm value"
echo "But both variables are allocated ONCE before the loop"
echo "non_zero_elt_log_idx with alloc_shared = SHARED MEMORY"
echo "Multiple threads write to same shared array = RACE CONDITION"

Repository: tile-ai/tilelang

Length of output: 1578


Resolve TODO and verify thread-safety of allocations within parallel loop.

The TODO on line 260 flags a legitimate concern. Both non_zero_cnt and non_zero_elt_log_idx are allocated before the parallel loop and used within it, creating concurrency hazards:

  1. Line 262 race condition: non_zero_elt_log_idx uses T.alloc_shared, creating shared memory accessible by all threads. Multiple threads write to this array concurrently (lines 269, 272, 277–278) without synchronization or thread-local indexing, causing data races.

  2. Line 261 uncertainty: non_zero_cnt uses T.alloc_var (default "local.var" scope), which should provide thread-local storage. However, the TODO suggests the semantics may not be working as intended in this parallel context. The reassignment at line 268 and increment at line 276 need guaranteed thread-local isolation.

Before merging, verify or fix:

  • Replace alloc_shared for non_zero_elt_log_idx with thread-local allocation (one possible shape is sketched below), or add thread-safety mechanisms.
  • Confirm alloc_var provides proper thread-local semantics within T.Parallel loops, or replace with an appropriate thread-local allocation.
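
A minimal sketch of the thread-local route, assuming only the names from the review context (block_M, elem, tm) and not reproducing the actual compression logic:

non_zero_cnt = T.alloc_fragment((block_M,), dtype=T.uint8)
non_zero_elt_log_idx = T.alloc_fragment((block_M, elem), dtype=T.uint8)
for tm in T.Parallel(block_M):
    # Row tm is private to this thread, so the writes that previously
    # aliased one shared array no longer race.
    non_zero_cnt[tm] = 0
    for te in T.serial(elem):
        non_zero_elt_log_idx[tm, te] = 0

Whether alloc_var gives per-thread semantics inside T.Parallel is exactly the open question above, so this sketch sidesteps it by keeping the counter per-row as well.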

…ample

* Updated index handling in `sparse_mla_fwd_pipelined.py` to eliminate unnecessary local array usage, improving code clarity and performance.
* Replaced instances of `indices_local[0]` with direct usage of `indices_local` for better readability and consistency in buffer access.
* Commented out the main execution call in the GDN test script to focus on the specific test function, enhancing test clarity.
@LeiWang1999
Member Author

/regression-perf

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

148-152: Critical: max_split remains uninitialized - previous review issue not addressed.

This critical issue was flagged in a previous review but has not been fixed. At line 151, max_split is only assigned when lse_local_split != 0. If all log-sum-exp values are zero (when no valid blocks are processed in any split), max_split is never initialized but is used in conditionals at lines 155 and 160.

The previous array-based implementation would have zero-initialized max_split[0], providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.

🔎 Required fix: Initialize max_split before the loop
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
+max_split = -1
 for k in T.serial(num_split):
     lse_local_split = glse[bz, by, k]
     if lse_local_split != 0:
         max_split = k
         lse_max_local = T.max(lse_max_local, glse[bz, by, k])

Initializing to -1 ensures that if no valid splits are found, the conditions k <= max_split at lines 155 and 160 will be false for all k ≥ 0, correctly skipping the accumulation loops.

🧹 Nitpick comments (2)
examples/deepseek_mla/example_mla_decode.py (1)

175-177: Consider consistent initialization pattern for scalar variables.

Line 175 uses T.clear(lse_logsum_local) while line 177 uses direct assignment lse_max_local = -T.infinity(accum_dtype) to initialize scalar variables. Both approaches work, but for consistency, consider using the same pattern for all scalar variable initializations.

Direct assignment (as on line 177) may be more idiomatic for scalars and aligns better with the scalar variable semantics.

🔎 Suggested refactor for consistency
-            T.clear(lse_logsum_local)
+            lse_logsum_local = T.cast(0, accum_dtype)
             T.clear(o_accum_local)
             lse_max_local = -T.infinity(accum_dtype)

Alternatively, if T.clear() is the preferred TileLang pattern for initialization:

             T.clear(lse_logsum_local)
             T.clear(o_accum_local)
-            lse_max_local = -T.infinity(accum_dtype)
+            T.clear(lse_max_local)
+            lse_max_local = -T.infinity(accum_dtype)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

145-147: Consider using direct assignment for scalar initialization.

Line 145 uses T.clear() on the scalar variable lse_logsum_local, while line 147 uses direct assignment for lse_max_local. For consistency and idiomatic style, consider initializing lse_logsum_local with direct assignment:

lse_logsum_local = 0
🔎 Proposed refactor for consistency
-T.clear(lse_logsum_local)
+lse_logsum_local = 0
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b2d281c and a377fa5.

📒 Files selected for processing (13)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py
  • examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  • examples/deepseek_mla/example_mla_decode.py
  • examples/deepseek_v32/fp8_lighting_indexer.py
  • examples/deepseek_v32/sparse_mla_fwd_pipelined.py
  • examples/fusedmoe/example_fusedmoe_tilelang.py
  • examples/gdn/test_example_gdn_compilation.py
  • examples/minference/example_vertical_slash_sparse_attn.py
  • examples/seer_attention/block_sparse_attn_tilelang.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  • examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
  • examples/deepseek_v32/fp8_lighting_indexer.py
  • examples/fusedmoe/example_fusedmoe_tilelang.py
  • examples/seer_attention/block_sparse_attn_tilelang.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/gdn/test_example_gdn_compilation.py
🧬 Code graph analysis (2)
examples/deepseek_mla/example_mla_decode.py (1)
tilelang/language/allocate.py (3)
  • alloc_var (89-89)
  • alloc_var (93-93)
  • alloc_var (96-150)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
tilelang/language/allocate.py (3)
  • alloc_var (89-89)
  • alloc_var (93-93)
  • alloc_var (96-150)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
examples/gdn/test_example_gdn_compilation.py (1)

319-320: Verify if this change is intentional for production.

This change replaces the test harness with a direct call to a single test function, meaning the remaining test functions in this file (test_example_wy_fast_compilation, test_example_wy_fast_bwd_split_compilation, test_example_chunk_o_compilation, test_example_chunk_o_bwd_compilation, test_example_chunk_scaled_dot_kkt_compilation, test_example_cumsum_compilation, test_example_chunk_delta_h_compilation) will no longer execute when running this script directly.

While the commit message mentions "focus testing," the use of a comment (rather than deletion) suggests this might be a temporary debugging change. If this is intended to be permanent, please clarify the rationale for running only this specific test function.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)

115-115: LGTM: Clean refactoring to scalar variable.

The change from a 1-element local array to a scalar variable (T.alloc_var) is correct and improves code clarity by eliminating unnecessary array indirection.


266-305: LGTM: All usages correctly updated.

All reads and writes to indices_local have been correctly updated to work with the scalar variable, removing the unnecessary [0] indexing. The refactoring maintains the original logic while improving readability.

examples/minference/example_vertical_slash_sparse_attn.py (4)

133-135: LGTM! Clean refactoring to scalar variables.

The migration from single-element arrays to scalar variables using T.alloc_var improves code clarity by eliminating unnecessary array indirection for block_count and column_count.


146-147: LGTM! Scalar assignments are correct.

The direct scalar assignments properly replace the previous array-indexed pattern and work correctly with the T.alloc_var allocations.


163-177: LGTM! Loop constructs correctly use scalar values.

Both T.serial(block_count) loop constructs (lines 163, 177) correctly consume the scalar variable, eliminating the previous [0] indexing.


210-285: LGTM! Comprehensive scalar usage throughout the conditional block.

All references to column_count within the conditional block have been correctly updated to use the scalar variable:

  • Conditional check (line 210)
  • Function parameter passing to Prefetch and Compute macros
  • Arithmetic operations in T.ceildiv(column_count, block_N)

The changes are thorough and consistent with the macro signatures that expect column_count: T.int32.
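
Illustrative shape of that usage, with the call sites assumed from the comment above (Prefetch and Compute are the example's macros, taking column_count: T.int32):

if column_count > 0:
    # scalar is passed straight through; no [0] indexing remains
    Prefetch(column_count)
    for bi in T.serial(T.ceildiv(column_count, block_N)):
        Compute(column_count)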

examples/deepseek_mla/example_mla_decode.py (2)

170-173: LGTM! Scalar variable allocation refactoring improves clarity.

The conversion from T.alloc_local([1], accum_dtype) to T.alloc_var(accum_dtype) for these scalar variables removes unnecessary array notation and makes the intent explicit. This aligns well with the PR objective to phase out legacy single-element array allocations.
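
The shape of the conversion on one of the affected scalars (illustrative; the surrounding kernel is unchanged):

# before: 1-element local array plus [0] indexing
lse_max_local = T.alloc_local([1], accum_dtype)
lse_max_local[0] = -T.infinity(accum_dtype)
# after: scalar variable, assigned directly
lse_max_local = T.alloc_var(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)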


179-192: Excellent refactoring of scalar operations.

The removal of [0] indexing throughout these scalar operations significantly improves code readability while preserving the original logic:

  • Max reduction (line 179)
  • Log-sum-exp accumulation (lines 181-183)
  • Scaling computation (lines 187-188)
  • Final output accumulation (line 190)

The refactoring correctly distinguishes between scalar variables (now using direct operations) and fragment arrays like po_local and o_accum_local (which still use proper indexing).

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)

139-143: Scalar allocations look good.

The conversion from 1-element local arrays to scalar variables using T.alloc_var is consistent with the PR's refactoring objectives and simplifies the code.


154-168: Scalar operations refactored correctly.

The conversion to direct scalar operations (removing [0] indexing) is clean and maintains the original logic. The code is now more readable and consistent with the PR's objectives.

Note: The correctness of lines 155 and 160 depends on properly initializing max_split as flagged in the separate comment above.

@LeiWang1999
Member Author

@codex review

@LeiWang1999
Member Author

@regression-perf

@chatgpt-codex-connector
Copy link

Codex Review: Didn't find any major issues. Breezy!

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@xwhzz
Contributor

xwhzz commented Dec 22, 2025

@regression-perf

@LeiWang1999 LeiWang1999 merged commit 1d9a2ea into tile-ai:main Dec 22, 2025
6 checks passed