
[Cleanup] Remove unnecessary macros in tilelang examples#1514

Merged
LeiWang1999 merged 4 commits into tile-ai:main from Rachmanino:1223-rm-macro
Dec 24, 2025

Conversation

@Rachmanino
Collaborator

@Rachmanino Rachmanino commented Dec 23, 2025

Summary by CodeRabbit

  • Refactor
    • Consolidated macro-based attention stages into inline, per-iteration implementations across attention kernels, simplifying control flow.
  • API
    • Added explicit split vs. no-split entry points and explicit partial-output parameters for split-mode flows, clarifying split behavior and partial-result aggregation.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Dec 23, 2025

Warning

Rate limit exceeded

@LeiWang1999 has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 19 minutes and 38 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between bf5e2d9 and 5bcb9cb.

📒 Files selected for processing (2)
  • examples/flash_decoding/example_gqa_decode_varlen_logits.py
  • examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
📝 Walkthrough

Walkthrough

This PR inlines previously macro-based helpers (MMA0, MMA1, Softmax, Rescale) into many attention kernels' main loops and replaces several macro split/combine flows with explicit prim_func entry points (main_split / main_no_split) that add parameters like glse, Output_partial, and Output. No widespread external API removals; some prim_func signatures were added or adjusted.
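To make the split/no-split dispatch concrete, here is a minimal host-side sketch in plain PyTorch. It is not code from this PR: the kernel handles, argument order, and allocation pattern are assumptions; only the tensor shapes mirror the glse / Output_partial / Output signatures quoted later in this review.

    import torch

    def run_decode(main_split, main_no_split, Q, K, V, num_split, dim):
        """Hypothetical dispatcher: pick the split or no-split kernel at runtime."""
        batch, heads = Q.shape[0], Q.shape[1]
        Output = torch.empty(batch, heads, dim, device=Q.device, dtype=Q.dtype)
        if num_split == 1:
            # No-split path: the kernel writes the final result directly to Output.
            main_no_split(Q, K, V, Output)
            return Output
        # Split path: per-split log-sum-exp values and partial outputs, combined
        # (inside the kernel or in a separate combine stage) into Output.
        glse = torch.empty(batch, heads, num_split, device=Q.device, dtype=Q.dtype)
        Output_partial = torch.empty(batch, heads, num_split, dim, device=Q.device, dtype=Q.dtype)
        main_split(Q, K, V, glse, Output_partial, Output)
        return Output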

Changes

Cohort / File(s) / Change Summary
Attention Sink Macro Inlining
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py, examples/attention_sink/example_mha_sink_fwd_bhsd.py, examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
Removed MMA0/MMA1/Softmax/Rescale macros; inlined K/Q/V copies, acc_s/acc_o computations, softmax-like reductions and GEMMs directly in the pipelined loop. No public signature changes.
Flash Attention Macro Inlining
examples/flash_attention/... (multiple files: example_gqa_fwd_bshd.py, example_gqa_fwd_bshd_wgmma_pipelined.py, example_mha_fwd_bhsd.py, example_mha_fwd_bhsd_wgmma_pipelined.py, example_mha_fwd_bshd.py, example_mha_fwd_bshd_wgmma_pipelined.py)
Replaced macro-based stages with inline per-iteration logic: K/V shared copies, causal/OOB masking, QK^T GEMM, scores_max/scale/logsum updates (exp2-based), acc_s_cast scaling, and V GEMM; preserves outputs and signatures.
Blocksparse / Seer Attention Inlining
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py, examples/seer_attention/block_sparse_attn_tilelang.py
Inlined Softmax/Rescale into the block loop: scores_max tracking, exp2-based scaling, sums/logsum updates, acc_s/acc_o adjustments, and final normalization.
Sparse GQA Decode → prim_func
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
Collapsed macro split/combine into single main prim_func variants; added explicit Output parameter (and preserved Output_partial where present); removed intermediate macro wrappers and combine macros.
Flash Decoding / MHA Inference → prim_func
examples/flash_decoding/example_gqa_decode.py, examples/flash_decoding/example_gqa_decode_varlen_logits.py, examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py, examples/flash_decoding/example_mha_inference.py
Replaced macro-based flash_attn and split/combine with prim_func entry points (split/no-split variants). Added/annotated Output, Output_partial, and glse parameters and adjusted Output shapes where applicable.
DeepSeek MLA Decode (split/no-split refactor)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py, examples/deepseek_mla/example_mla_decode.py, examples/deepseek_mla/example_mla_decode_paged.py, examples/deepseek_mla/example_mla_decode_ws.py
Replaced macro-based flash-attn with main_split and main_no_split prim_funcs; introduced glse, Output_partial, explicit per-split loops/indices, combine logic, and added Output to signatures.
GEMM StreamK Macro Inlining
examples/gemm_streamk/example_tilelang_gemm_streamk.py
Removed compute_first_wave / compute_full_tiles macros and inlined first-wave and full-tile GEMM loop logic inside tl_matmul_streamk.
Warp Specialize Signature Update
examples/warp_specialize/example_warp_specialize_flashmla.py
Converted macro to prim_func main_no_split; added glse and Output_partial parameters; removed duplicate definition and changed default main() batch from 1 → 132.

Sequence Diagram(s)

(omitted — changes are dispersed across many example kernels and are primarily macro inlining / prim_func reorganization rather than a new multi-component runtime flow)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • chengyupku

"A rabbit hops through code so neat,
I stitched the macros into one heartbeat.
No signatures lost, just tidy inline art —
Softmax and rescale now play their part. 🐇✨"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title clearly describes the main change: removing unnecessary macros from tilelang examples, which is the primary objective of this PR.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/flash_decoding/example_gqa_decode.py (1)

182-186: Potential out-of-bounds access when dim != 128.

scale_local is allocated with shape [128] (line 167), but at line 186 it's indexed by i where i iterates over dim. If dim > 128, this causes an out-of-bounds access. The existing comment on line 184 acknowledges this coupling but doesn't prevent the issue.

Since scale is uniform across the dimension (a per-head scalar), consider either:

  1. Reducing scale_local to a scalar and broadcasting explicitly, or
  2. Allocating scale_local with size [dim] and adjusting the parallel range accordingly.
🔎 Suggested fix (Option 1 - use scalar broadcast)
-            scale_local = T.alloc_fragment([128], accum_dtype)
+            scale_local = T.alloc_fragment([1], accum_dtype)
 ...
             for k in T.serial(num_split):
                 for i in T.Parallel(dim):
                     po_local[i] = Output_partial[bz, by, k, i]
-                for j in T.Parallel(128):
-                    scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
+                scale_local[0] = T.exp2(lse_local[k, 0] - lse_logsum_local[0])
                 # Note: Pay attention to dim and the number of threads in Parallel
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[i]
+                    o_accum_local[i] += po_local[i] * scale_local[0]
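As a cross-check on the arithmetic above, here is a reference-only NumPy version of the log-sum-exp combine that this loop implements. It assumes glse / lse_local holds the per-split log2-domain softmax denominator (the value the kernels accumulate as logsum); the function name and layout are illustrative, not taken from the example files. Because the weight is a single scalar per split, it also illustrates why scale_local needs only one element per row, as in Option 1 above.

    import numpy as np

    def combine_splits(output_partial, lse):
        # output_partial: [num_split, dim]; lse: [num_split], log2 domain
        lse_logsum = np.log2(np.sum(np.exp2(lse - lse.max()))) + lse.max()
        scale = np.exp2(lse - lse_logsum)          # per-split weights, sum to 1
        return (output_partial * scale[:, None]).sum(axis=0)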
🧹 Nitpick comments (9)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)

246-254: Consider removing empty annotation block.

The T.annotate_layout({}) block with all layouts commented out provides no functionality. This could be cleaned up in a follow-up, though it's not breaking anything.

🔎 Proposed cleanup
-            T.annotate_layout(
-                {
-                    # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
-                    # K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-                    # V_shared: tilelang.layout.make_swizzled_layout(V_shared),
-                    # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
-                    # S_shared: tilelang.layout.make_swizzled_layout(S_shared),
-                }
-            )
-
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

72-102: Consider abstracting shared logic between layout variants.

The inline flash attention logic is nearly identical to example_mha_fwd_bshd_wgmma_pipelined.py, differing only in tensor layout indexing. While this duplication may be intentional for example clarity, you could reduce maintenance burden by abstracting the common kernel logic.

examples/seer_attention/block_sparse_attn_tilelang.py (1)

93-117: The inlined Softmax and Rescale logic is correct.

The FlashAttention algorithm is implemented correctly:

  • Max score tracking across blocks (lines 93-97)
  • Rescaling factor computation using exp2 (lines 103-104)
  • Softmax with log2(e) optimization (lines 105-109)
  • Running logsum update (lines 111-112)
  • Accumulated output rescaling (lines 115-116)

However, consider these optional refinements:

1. Simplify initialization (lines 94-95):
The explicit T.fill on line 94 followed by reduce_max with clear=False on line 95 is redundant. Since reduce_max with clear=True (the default) automatically initializes the output buffer to -infinity, you can simplify:

🔎 Optional simplification
-T.copy(scores_max, scores_max_prev)
-T.fill(scores_max, -T.infinity(accum_dtype))
-T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+T.copy(scores_max, scores_max_prev)
+T.reduce_max(acc_s, scores_max, dim=1, clear=True)

Both approaches are functionally equivalent, but the simplified version is more idiomatic.

2. Commented code (lines 101-102):
The commented-out Check_inf logic includes a clear explanation referencing FlashAttention3. If this check is not needed for correctness (as evidenced by passing tests), consider removing the commented code to reduce clutter. If it may be needed for future edge cases, consider adding a more prominent TODO or issue reference.

examples/flash_attention/example_mha_fwd_bshd.py (1)

59-88: Optional: Consider adding phase comments for readability.

The inlined logic is correct and self-contained. For improved readability, consider adding brief comments to delineate the major phases within the pipelined loop:

  • K copy and masking (lines 60-67)
  • Online softmax normalization (lines 69-82)
  • Output rescaling (lines 83-85)
  • V GEMM (lines 86-88)

This would help readers quickly navigate the attention computation flow.

examples/flash_decoding/example_gqa_decode.py (1)

200-200: Unused variable bz should be prefixed with underscore.

The kernel dimension bz is unpacked but never used in the no-split path (since num_split == 1). Prefix it with an underscore to indicate it's intentionally unused.

🔎 Suggested fix
-        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
+        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, _bz):
examples/flash_decoding/example_mha_inference.py (3)

84-85: Consider removing or conditionally enabling commented Check_inf code.

The commented code addresses -inf handling for causal softmax. Either remove it if not needed, or implement it with the appropriate condition mentioned in the comment ("only need to be done in the first ceil_div(kBlockM, kBlockN) steps").


64-67: Extract split offset calculation for readability.

The repeated expression (seqlen_kv // num_split) * sid in K and V indexing could be extracted to improve clarity.

🔎 Suggested refactor

Before the loop (after line 62):

kv_split_offset = (seqlen_kv // num_split) * sid

Then simplify lines 65-66 and 102-103:

-                K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :],
+                K[bid, kv_split_offset + k * block_N : kv_split_offset + (k + 1) * block_N, hid, :],
-                V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :],
+                V[bid, kv_split_offset + k * block_N : kv_split_offset + (k + 1) * block_N, hid, :],

Also applies to: 101-104


167-174: Document unused parameters in ref_program.

The glse and Output_partial parameters are unused in the function body but appear required for signature matching with the profiler system. Consider adding a comment explaining this, or prefix them with underscore (_glse, _Output_partial) to indicate they're intentionally unused.

examples/gemm_streamk/example_tilelang_gemm_streamk.py (1)

93-130: LGTM! The first wave computation inlining looks correct.

The macro-to-inline transformation preserves the original stream-k semantics correctly, with proper handling of partial tiles via atomic operations.

Optional: Eliminate redundant subexpression calculation

The expression start_iter[0] % iters_per_tile is computed three times (lines 107, 114, 118). Consider reusing the remain_iters variable to improve readability:

             tile_id = start_iter[0] // iters_per_tile
             remain_iters = start_iter[0] % iters_per_tile
             pid_m = tile_id // T.ceildiv(N, block_N)
             pid_n = tile_id % T.ceildiv(N, block_N)

             T.clear(C_local)
             for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
                 T.copy(
-                    A[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
+                    A[pid_m * block_M, (k + remain_iters) * block_K],
                     A_shared,
                 )
                 T.copy(
-                    B[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
+                    B[pid_n * block_N, (k + remain_iters) * block_K],
                     B_shared,
                 )
                 T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 783694f and 64dc38a.

📒 Files selected for processing (24)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  • examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
  • examples/deepseek_mla/example_mla_decode.py
  • examples/deepseek_mla/example_mla_decode_paged.py
  • examples/deepseek_mla/example_mla_decode_ws.py
  • examples/flash_attention/example_gqa_fwd_bshd.py
  • examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
  • examples/flash_attention/example_mha_fwd_bhsd.py
  • examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
  • examples/flash_attention/example_mha_fwd_bshd.py
  • examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
  • examples/flash_decoding/example_gqa_decode.py
  • examples/flash_decoding/example_gqa_decode_varlen_logits.py
  • examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
  • examples/flash_decoding/example_mha_inference.py
  • examples/gemm_streamk/example_tilelang_gemm_streamk.py
  • examples/seer_attention/block_sparse_attn_tilelang.py
  • examples/warp_specialize/example_warp_specialize_flashmla.py
🧰 Additional context used
🧬 Code graph analysis (16)
examples/seer_attention/block_sparse_attn_tilelang.py (4)
src/tl_templates/cuda/reduce.h (1)
  • T (178-250)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/fill_op.py (2)
  • fill (9-36)
  • clear (39-62)
tilelang/language/reduce_op.py (2)
  • reduce_max (107-125)
  • reduce_sum (144-166)
examples/gemm_streamk/example_tilelang_gemm_streamk.py (3)
tilelang/language/loop.py (2)
  • Pipelined (58-95)
  • Parallel (13-33)
tilelang/language/copy_op.py (1)
  • copy (14-95)
examples/gemm/example_gemm.py (1)
  • gemm (8-24)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
src/op/gemm.h (4)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
examples/flash_decoding/example_gqa_decode.py (1)
  • flashattn_gqa_decode_no_split (191-254)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (1)
  • flashattn_gqa_decode_no_split (59-142)
examples/flash_attention/example_mha_fwd_bshd.py (5)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/loop.py (1)
  • Parallel (13-33)
examples/gemm/example_gemm.py (1)
  • gemm (8-24)
tilelang/language/fill_op.py (2)
  • fill (9-36)
  • clear (39-62)
tilelang/language/reduce_op.py (2)
  • reduce_max (107-125)
  • reduce_sum (144-166)
examples/flash_attention/example_gqa_fwd_bshd.py (3)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/loop.py (1)
  • Parallel (13-33)
tilelang/language/reduce_op.py (2)
  • reduce_max (107-125)
  • reduce_sum (144-166)
examples/warp_specialize/example_warp_specialize_flashmla.py (4)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (531-541)
  • prim_func (788-798)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • prim_func (272-274)
tilelang/jit/adapter/cython/adapter.py (1)
  • prim_func (356-358)
tilelang/language/tir/entry.py (1)
  • prim_func (10-58)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (3)
examples/deepseek_mla/example_mla_decode.py (2)
  • main_split (25-126)
  • main_no_split (129-191)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
  • main_split (138-181)
examples/warp_specialize/example_warp_specialize_flashmla.py (1)
  • main_no_split (21-294)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (3)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/loop.py (1)
  • Parallel (13-33)
examples/gemm/example_gemm.py (1)
  • gemm (8-24)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
  • main (35-155)
  • main (325-398)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
  • main (36-140)
  • main (313-380)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)
examples/flash_decoding/example_gqa_decode.py (1)
  • flashattn_gqa_decode_no_split (191-254)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
  • flashattn_gqa_decode_no_split (220-317)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (531-541)
  • prim_func (788-798)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)
  • main (43-165)
  • main (331-512)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
  • main (36-140)
  • main (313-380)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
tilelang/tileop/gemm/gemm_base.py (2)
  • K (43-44)
  • policy (120-121)
src/op/gemm.h (4)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
examples/deepseek_mla/example_mla_decode.py (3)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
  • main_split (46-136)
examples/deepseek_mla/example_mla_decode_paged.py (1)
  • main_split (28-137)
examples/deepseek_mla/example_mla_decode_ws.py (1)
  • main_split (37-280)
examples/flash_decoding/example_mha_inference.py (4)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • prim_func (272-274)
tilelang/language/copy_op.py (1)
  • copy (14-95)
examples/gemm/example_gemm.py (1)
  • gemm (8-24)
src/op/gemm.h (4)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
examples/deepseek_mla/example_mla_decode_ws.py (2)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (2)
  • main_split (46-136)
  • main_no_split (139-197)
examples/deepseek_mla/example_mla_decode.py (2)
  • main_split (25-126)
  • main_no_split (129-191)
🪛 Ruff (0.14.10)
examples/warp_specialize/example_warp_specialize_flashmla.py

26-26: Unused function argument: glse

(ARG001)


27-27: Unused function argument: Output_partial

(ARG001)

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py

144-144: Unused function argument: glse

(ARG001)


145-145: Unused function argument: Output_partial

(ARG001)

examples/deepseek_mla/example_mla_decode_paged.py

147-147: Unused function argument: glse

(ARG001)


148-148: Unused function argument: Output_partial

(ARG001)

examples/flash_decoding/example_gqa_decode.py

200-200: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/deepseek_mla/example_mla_decode_ws.py

288-288: Unused function argument: glse

(ARG001)


289-289: Unused function argument: Output_partial

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (43)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)

26-26: LGTM on JIT decorator cleanup.

Removal of debug_root_path simplifies the decorator usage appropriately for production examples.


58-68: LGTM on prim_func conversion.

The conversion from macro to @T.prim_func is consistent with the broader PR pattern. The function signature correctly includes BLOCK_TABLE for paged KV cache support, and the output tensor uses the dynamic shape_o definition.

examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)

201-202: LGTM on decorator stack.

The @autotune and @tilelang.jit(out_idx=[-2, -1]) decorators are correctly stacked. The out_idx=[-2, -1] correctly identifies Output and S as the output tensors.


219-228: LGTM on prim_func conversion.

The conversion from macro to @T.prim_func is correctly implemented. The function signature properly declares the 7 parameters with Output and S as outputs (matching out_idx=[-2, -1]). The caller at line 353 correctly passes only the 5 input arguments since the JIT decorator handles output tensor allocation.

examples/flash_attention/example_mha_fwd_bhsd.py (3)

66-76: LGTM: K loading, masking, and attention score computation correctly implemented.

The inlined operations correctly implement:

  • Block-wise K loading into shared memory
  • Causal masking with proper query/key index calculation using the past_len offset
  • Padding-aware masking for non-causal paths
  • Accumulation of attention scores via GEMM (Q @ K^T) on top of mask values

77-89: LGTM: Online softmax algorithm correctly implements numerically stable attention weights.

The inlined softmax computation properly maintains:

  • Running maximum tracking across blocks (lines 77-81)
  • Exponential rescaling of previous accumulations via scores_scale (lines 82-83)
  • Numerically stable exponentiation with max subtraction and log2(e) scaling (lines 84-85)
  • Running sum updates with proper scale factors (lines 86-88)

This is the standard flash attention online softmax pattern for memory-efficient, numerically stable attention.
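For readers mapping the kernel back to the algorithm, a compact NumPy reference of the same online-softmax recurrence follows (single query vector for brevity). It uses base-e exp, whereas the kernels fold log2(e) into scale and use exp2; the function and argument names are illustrative, not from the example files.

    import numpy as np

    def online_softmax_attention(q, K_blocks, V_blocks, scale):
        # q: [dim]; K_blocks, V_blocks: lists of [block_N, dim] arrays
        m = -np.inf                               # running max of scores
        l = 0.0                                   # running softmax denominator
        acc = np.zeros(V_blocks[0].shape[1])      # running weighted sum of V
        for K, V in zip(K_blocks, V_blocks):
            s = (K @ q) * scale                   # scores for this block
            m_new = max(m, s.max())
            alpha = np.exp(m - m_new)             # rescale previous contributions
            p = np.exp(s - m_new)
            l = l * alpha + p.sum()
            acc = acc * alpha + p @ V
            m = m_new
        return acc / l                            # final normalization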


91-96: LGTM: Output rescaling and value accumulation correctly implement the flash attention pattern.

The inlined operations properly:

  • Rescale previous output accumulations by scores_scale to account for updated max values (lines 91-92)
  • Load the corresponding V block into shared memory (line 94)
  • Accumulate new contributions via GEMM (attention_weights @ V) into acc_o (line 95)

This correctly completes the online attention computation before final normalization.

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (3)

88-95: LGTM! K loading and masking logic is correct.

The inlined operations correctly implement:

  • Grouped query attention indexing (by // groups)
  • Causal masking (future tokens masked with -inf when query position < key position)
  • Sequence boundary checks for non-causal attention

97-109: LGTM! Online softmax implementation is correct.

The inlined softmax operations correctly implement the numerically stable online algorithm with:

  • Running max tracking across blocks
  • Proper rescaling of previous accumulations
  • Log2-space computations (using the log2(e) scale factor from line 44)

111-115: LGTM! Rescaling and final GEMM operations are correct.

The inlined operations properly:

  • Rescale accumulated outputs before adding new block contributions
  • Use consistent group indexing for V (matching K's indexing)
  • Perform the final attention-weighted value accumulation

The macro inlining appears to be a clean refactoring that improves code readability by making the computation steps explicit. Since the pipeline schedule parameters (order, stage, group) remain unchanged, you may want to verify that the pipelined execution still performs as expected after this refactoring.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (4)

72-81: LGTM! K-side GEMM with proper causal masking.

The inline logic correctly handles:

  • K block loading into shared memory
  • Causal masking with past_len offset for decoding scenarios
  • Out-of-bounds masking for non-causal paths
  • Q @ K^T computation via transposed GEMM

83-95: LGTM! Numerically stable online softmax.

The inline softmax correctly implements the incremental max-sum tracking pattern:

  • Maintains running max across blocks
  • Computes per-block exponentials with numerically stable scaling
  • Updates running logsum with previous block rescaling
  • Uses exp2 with pre-scaled factors for efficiency

97-98: LGTM! Correct accumulator rescaling.

The per-row rescaling of acc_o by scores_scale correctly adjusts previous contributions when the running max changes.


100-101: LGTM! V-side GEMM completes the attention computation.

The V block loading and final GEMM correctly compute the attention-weighted values, accumulating across blocks into acc_o.

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (4)

66-73: LGTM! K-side GEMM with causal masking for equal-length sequences.

The inline logic correctly handles the BSHD layout with equal query/key sequence lengths:

  • K block loading with BSHD indexing
  • Causal masking without past_len offset (appropriate for seq_len == seq_len)
  • Out-of-bounds masking for the non-causal path
  • Q @ K^T computation

75-87: LGTM! Numerically stable online softmax.

Identical to the BHSD variant—correctly implements incremental max-sum tracking with rescaling.


89-90: LGTM! Correct accumulator rescaling.

Identical to the BHSD variant—correctly rescales acc_o before adding new V contributions.


92-93: LGTM! V-side GEMM completes the attention computation.

Correctly loads V with BSHD layout indexing and performs the final attention-weighted GEMM.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)

34-47: LGTM - Clean refactoring from macro to prim_func.

The consolidation of flash_attn_split macro into the main prim_func is well-structured. The inline comments at Lines 46-47 help document the origin of the inlined logic.


122-155: Combine kernel correctly integrated.

The combine stage properly aggregates Output_partial across splits using the log-sum-exp scaling pattern, producing the final Output.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

42-53: LGTM - Consistent refactoring for paged attention variant.

The consolidation follows the same pattern as other files, with the paged attention-specific block_table parameter properly integrated into the kernel logic.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)

35-44: LGTM - Macro removal consistent with other sparse attention variants.

The transition from macro-based to inline prim_func follows the established pattern, maintaining the block_mask-based sparse attention logic.

examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (1)

85-117: LGTM - Flash attention logic correctly inlined.

The previously macro-based operations (MMA, Softmax, Rescale) are cleanly inlined into the main loop. The implementation correctly:

  • Handles causal masking (Lines 86-90)
  • Computes online softmax with running max and log-sum-exp tracking
  • Applies rescaling to maintain numerical stability
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)

144-146: Unused parameters are intentional for API consistency.

The static analysis correctly identifies that glse and Output_partial are unused in main_no_split. This is expected - these parameters exist to maintain a consistent function signature with main_split, allowing the caller to use the same interface regardless of which path is selected at runtime.

examples/deepseek_mla/example_mla_decode.py (2)

24-32: LGTM - Clean split/no-split separation.

The main_split function correctly implements the split-path execution with per-split glse and Output_partial handling, followed by a combine stage.


128-196: LGTM - Non-split path correctly writes directly to Output.

The main_no_split function properly bypasses the partial output mechanism and writes the final result directly to Output, which is the expected behavior when num_split == 1.

examples/deepseek_mla/example_mla_decode_paged.py (2)

27-38: LGTM - Paged attention variant with proper block table handling.

The main_split correctly incorporates block table lookups for paged KV cache access while maintaining the split-path execution pattern.


147-149: Unused parameters are intentional for API consistency.

Similar to the AMD benchmark file, glse and Output_partial are unused in main_no_split to maintain signature compatibility with main_split. This enables the conditional return at Lines 214-217 to work with a uniform interface.

examples/deepseek_mla/example_mla_decode_ws.py (3)

36-47: LGTM - Warp-specialized split path correctly structured.

The main_split function properly implements the warp-specialized execution pattern with:

  • Thread range 0-127: QK GEMM and left-half output accumulation
  • Thread range 128-255: Right-half output accumulation
  • Thread range 256+: Producer (KV data loading)

The split-aware partial output writes at Lines 177-178 and 208 are correct.


254-280: Combine stage correctly aggregates split outputs.

The combine kernel uses the standard log-sum-exp scaling pattern to merge Output_partial across splits into the final Output.


288-290: Unused parameters are intentional for API consistency.

As with other files, glse and Output_partial are unused in main_no_split to maintain a uniform function signature with main_split.

examples/warp_specialize/example_warp_specialize_flashmla.py (1)

354-354: Verify the rationale for changing the default batch size from 1 to 132.

The default batch size was changed from 1 to 132, which is a significant increase (132×). The value 132 is unusual (not a power of 2 or typical batch size), which raises questions about whether this change is:

  • Intentional for specific benchmarking requirements
  • Related to a particular hardware configuration or test scenario
  • Accidentally introduced

A batch size of 1 is more typical for examples and quick testing, while 132 significantly increases memory usage and execution time.

Please clarify the rationale for this change. If intentional, consider adding a comment explaining why this specific value is used.

examples/flash_attention/example_gqa_fwd_bshd.py (4)

107-114: Masking and GEMM logic looks correct.

The inlined masking correctly handles both causal and non-causal attention:

  • Causal masking properly masks future tokens (query < key positions)
  • Non-causal masking correctly handles out-of-bounds positions
  • GQA indexing with by // groups appropriately maps query heads to KV heads
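As a sanity reference for the by // groups mapping, a tiny NumPy sketch of GQA follows (illustrative only; this is not the ref_program from the example): each group of `groups` consecutive query heads reads the same KV head.

    import numpy as np

    def gqa_reference(Q, K, V, groups):
        # Q: [q_heads, seq, dim]; K, V: [kv_heads, seq, dim]; q_heads == kv_heads * groups
        out = np.empty_like(Q)
        for qh in range(Q.shape[0]):
            kvh = qh // groups                    # same mapping as `by // groups`
            s = Q[qh] @ K[kvh].T / np.sqrt(Q.shape[-1])
            p = np.exp(s - s.max(axis=-1, keepdims=True))
            p /= p.sum(axis=-1, keepdims=True)
            out[qh] = p @ V[kvh]                  # non-causal, for shape checking only
        return out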

116-128: Online softmax normalization is correctly implemented.

The inlined sequence properly implements the online softmax algorithm with correct order of operations:

  • Max propagation and combination with previous iterations
  • Appropriate scaling factors computed in log2 domain
  • Running logsum correctly accumulated with per-iteration scaling

130-134: Accumulator scaling and final GEMM are correct.

The acc_o scaling by scores_scale before the GEMM correctly implements the online softmax rescaling, and the GQA indexing is consistent for both K and V tensors.


106-135: Verify that the inlined implementation matches the original macro behavior.

The inlined attention computation logic is correct and properly implements:

  • GQA with appropriate head indexing
  • Online softmax with numerical stability
  • Causal and non-causal masking

However, ensure that tests pass and the behavior is identical to the original macro-based implementation, particularly for edge cases like:

  • Sequence lengths that don't align with block boundaries
  • Different group configurations
  • Causal vs non-causal modes

If you'd like, I can generate a verification script to check for any remaining references to the removed macros (MMA0, MMA1, Softmax, Rescale) in this file or related test files to ensure the cleanup is complete.

examples/flash_attention/example_mha_fwd_bshd.py (4)

60-67: LGTM! Masking and GEMM logic is correct.

The inlined K copy, causal masking, and attention score computation are correctly implemented:

  • Causal masking properly masks future tokens (query_pos < key_pos)
  • Non-causal masking correctly handles out-of-bounds positions
  • The GEMM accumulates into pre-initialized mask values as expected

69-82: LGTM! Online softmax implementation is correct.

The inlined softmax normalization logic correctly implements the online/incremental softmax algorithm:

  • Properly tracks and updates running maximum (scores_max)
  • Correctly computes scale factors for previous accumulations
  • Uses exp2 with log2(e) conversion for performance optimization
  • Maintains running sum (logsum) with appropriate rescaling

83-85: LGTM! Rescale operation is correct.

The rescaling of acc_o by scores_scale is a necessary step in the online softmax algorithm to account for updated maximum scores.


86-88: LGTM! Value GEMM is correct.

The V copy and attention-weighted GEMM operation correctly compute the attention output by accumulating across all key-value blocks.

examples/flash_decoding/example_gqa_decode.py (1)

256-259: LGTM!

Clean runtime selection between split and no-split paths based on num_split. This properly handles the two execution modes.

examples/flash_decoding/example_mha_inference.py (2)

20-27: LGTM: Clean prim_func signature.

The conversion from macro-based to prim_func with explicit intermediate tensors (glse, Output_partial) is well-structured for the split-combine pattern.


56-73: Causal split masking incomplete but properly documented and protected.

The causal mask on line 71 does not account for the split offset and would be incorrect when num_split > 1 and causal=True. However, this limitation is already explicitly documented via the TODO comments (lines 56, 68) and properly protected: ref_program (line 168) asserts causal is False, and tests default to causal=False, preventing incorrect behavior. The incomplete feature is intentionally deferred and does not pose a risk in the current state.
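If the deferred split-aware causal mask is implemented later, the key-index arithmetic would need to fold in the split offset along these lines (purely illustrative, reusing the kv_split_offset naming from the earlier refactor suggestion):

    # Illustrative only: split-aware causal predicate comparing global positions.
    def causal_keep(q_idx, j, k, block_N, seqlen_kv, num_split, sid):
        kv_split_offset = (seqlen_kv // num_split) * sid
        k_idx = kv_split_offset + k * block_N + j   # global key position, not split-local
        return q_idx >= k_idx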

Comment on lines +112 to +149
T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)


🛠️ Refactor suggestion | 🟠 Major

GQA variant duplicates logic with minor indexing differences.

This GQA implementation differs from the MHA versions only in the K/V indexing (by // groups on lines 112, 147). The core attention computation is identical. This is a strong signal that a parameterized helper function could serve all variants (MHA, GQA, with/without pipelining) while eliminating duplication.

A unified helper could accept an indexing callback or parameter to handle the MHA vs GQA difference, reducing the maintenance burden across all 27 affected files.
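One possible shape for such a helper, sketched in plain Python rather than the TileLang DSL (all names here are invented for illustration, not proposed API):

    # Sketch only: a single attention-block helper parameterized by the
    # query-head -> kv-head mapping, so MHA and GQA variants share one body.
    from typing import Callable

    def make_kv_head_map(groups: int) -> Callable[[int], int]:
        # groups == 1 gives the MHA identity mapping; groups > 1 gives GQA.
        return lambda q_head: q_head // groups

    def attention_block(q_head: int, kv_head_of: Callable[[int], int]):
        kv_head = kv_head_of(q_head)
        # ... load K[kv_head] / V[kv_head], then run the shared softmax/GEMM body ...
        return kv_head

    # MHA: attention_block(h, make_kv_head_map(1)); GQA: attention_block(h, make_kv_head_map(groups))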

Comment on lines +106 to +143
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)


🛠️ Refactor suggestion | 🟠 Major

Consider consolidating duplicated attention logic across files.

The inlined attention computation (K copy, masking, GEMM, softmax operations, V processing) is now duplicated across multiple files in this PR. According to the AI summary, 27 files receive similar inlined logic. This creates maintenance challenges:

  • Bug fixes require changes in 27 locations
  • Performance optimizations must be replicated everywhere
  • Risk of inconsistencies if updates are missed

The PR removes macros but doesn't clarify why they were "unnecessary." Macros or helper functions are specifically designed to avoid this kind of duplication.

Could you clarify the rationale for removing the macro-based approach? If the goal is to improve readability or performance at specific call sites, consider:

  • Keeping a shared helper function with inline hints
  • Using a template/code generation approach during build
  • Documenting the specific issues that macros were causing

This would preserve the benefits of the refactor while maintaining DRY principles.

Comment on lines +97 to +134
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)


🛠️ Refactor suggestion | 🟠 Major

Same code duplication concern applies here.

This file contains the same inlined attention logic as example_mha_sink_fwd_bhsd_wgmma_pipelined.py. The duplication makes it difficult to maintain consistency across variants. Consider extracting the shared logic into a common helper.

Comment on lines +132 to +144
 # compute full tiles
 if sm_patition_factor > 0:
-    compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)
+    for p in T.serial(sm_patition_factor):
+        tile_id = pid + streamk_tiles + p * total_sm
+        pid_m = tile_id // T.ceildiv(N, block_N)
+        pid_n = tile_id % T.ceildiv(N, block_N)
+        T.clear(C_local)
+
+        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
+            T.copy(A[pid_m * block_M, k * block_K], A_shared_full_tiles)
+            T.copy(B[pid_n * block_N, k * block_K], B_shared_full_tiles)
+            T.gemm(A_shared_full_tiles, B_shared_full_tiles, C_local, transpose_B=trans_B)
+        T.copy(C_local, C[pid_m * block_M, pid_n * block_N])

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n examples/gemm_streamk/example_tilelang_gemm_streamk.py | head -160

Repository: tile-ai/tilelang

Length of output: 6420


🏁 Script executed:

#!/bin/bash
# Extract the variable definitions to verify the boundary condition
sed -n '15,54p' examples/gemm_streamk/example_tilelang_gemm_streamk.py | cat -n

echo "---"
echo "Analyzing boundary conditions:"
python3 << 'EOF'
# Simulate the values
streamk_programs = 108

# Test case 1: Example from file
total_tiles_1 = (256 // 16) * (1024 // 128)  # 16 * 8 = 128
streamk_tiles_1 = total_tiles_1 % streamk_programs  # 128 % 108 = 20
if (total_tiles_1 - streamk_tiles_1) > streamk_programs:
    streamk_tiles_1 += streamk_programs
blocking_tiles_1 = total_tiles_1 - streamk_tiles_1
sm_partition_factor_1 = max(blocking_tiles_1 // streamk_programs, 1)
print(f"Case 1 (m=256, n=1024, k=512):")
print(f"  total_tiles={total_tiles_1}, streamk_tiles={streamk_tiles_1}, blocking_tiles={blocking_tiles_1}")
print(f"  sm_partition_factor={sm_partition_factor_1}")
print()

# Test case 2: When total_tiles < streamk_programs (edge case)
total_tiles_2 = 50
streamk_tiles_2 = total_tiles_2 % streamk_programs  # 50 % 108 = 50
if (total_tiles_2 - streamk_tiles_2) > streamk_programs:
    streamk_tiles_2 += streamk_programs
blocking_tiles_2 = total_tiles_2 - streamk_tiles_2
sm_partition_factor_2 = max(blocking_tiles_2 // streamk_programs, 1)
print(f"Case 2 (total_tiles=50, streamk_programs=108):")
print(f"  total_tiles={total_tiles_2}, streamk_tiles={streamk_tiles_2}, blocking_tiles={blocking_tiles_2}")
print(f"  sm_partition_factor={sm_partition_factor_2}")
print(f"  ISSUE: sm_partition_factor={sm_partition_factor_2} > 0 is True, but blocking_tiles={blocking_tiles_2}")
print()

# Test case 3: When total_tiles is close to streamk_programs
total_tiles_3 = 100
streamk_tiles_3 = total_tiles_3 % streamk_programs  # 100 % 108 = 100
if (total_tiles_3 - streamk_tiles_3) > streamk_programs:
    streamk_tiles_3 += streamk_programs
blocking_tiles_3 = total_tiles_3 - streamk_tiles_3
sm_partition_factor_3 = max(blocking_tiles_3 // streamk_programs, 1)
print(f"Case 3 (total_tiles=100, streamk_programs=108):")
print(f"  total_tiles={total_tiles_3}, streamk_tiles={streamk_tiles_3}, blocking_tiles={blocking_tiles_3}")
print(f"  sm_partition_factor={sm_partition_factor_3}")
EOF

Repository: tile-ai/tilelang

Length of output: 146


🏁 Script executed:

python3 << 'EOF'
# Simulate the boundary condition logic
streamk_programs = 108

# Test case 1: Example from file
total_tiles_1 = (256 // 16) * (1024 // 128)  # 16 * 8 = 128
streamk_tiles_1 = total_tiles_1 % streamk_programs  # 128 % 108 = 20
if (total_tiles_1 - streamk_tiles_1) > streamk_programs:
    streamk_tiles_1 += streamk_programs
blocking_tiles_1 = total_tiles_1 - streamk_tiles_1
sm_partition_factor_1 = max(blocking_tiles_1 // streamk_programs, 1)
print(f"Case 1 (m=256, n=1024, k=512):")
print(f"  total_tiles={total_tiles_1}, streamk_tiles={streamk_tiles_1}, blocking_tiles={blocking_tiles_1}")
print(f"  sm_partition_factor={sm_partition_factor_1}")
print(f"  Condition 'sm_partition_factor > 0' is {sm_partition_factor_1 > 0}")
print()

# Test case 2: When total_tiles < streamk_programs
total_tiles_2 = 50
streamk_tiles_2 = total_tiles_2 % streamk_programs  # 50 % 108 = 50
if (total_tiles_2 - streamk_tiles_2) > streamk_programs:
    streamk_tiles_2 += streamk_programs
blocking_tiles_2 = total_tiles_2 - streamk_tiles_2
sm_partition_factor_2 = max(blocking_tiles_2 // streamk_programs, 1)
print(f"Case 2 (total_tiles=50):")
print(f"  blocking_tiles={blocking_tiles_2}, sm_partition_factor={sm_partition_factor_2}")
print(f"  ISSUE: Condition is True but blocking_tiles=0!")
print()

# Test case 3: When total_tiles equals streamk_programs
total_tiles_3 = 108
streamk_tiles_3 = total_tiles_3 % streamk_programs  # 108 % 108 = 0
if (total_tiles_3 - streamk_tiles_3) > streamk_programs:
    streamk_tiles_3 += streamk_programs
blocking_tiles_3 = total_tiles_3 - streamk_tiles_3
sm_partition_factor_3 = max(blocking_tiles_3 // streamk_programs, 1)
print(f"Case 3 (total_tiles=108):")
print(f"  blocking_tiles={blocking_tiles_3}, sm_partition_factor={sm_partition_factor_3}")
print(f"  Condition is True but blocking_tiles=0!")
EOF

Repository: tile-ai/tilelang

Length of output: 446


Fix guard condition for full tiles computation.

The condition if sm_patition_factor > 0: is incorrect. When blocking_tiles = 0 (which occurs when total_tiles ≤ streamk_programs), sm_patition_factor = max(0 // total_sm, 1) evaluates to 1, causing the condition to pass despite having no blocking tiles to process. This leads to invalid tile indexing.

Change line 133 to if blocking_tiles > 0: to ensure the loop executes only when blocking tiles actually exist.

🤖 Prompt for AI Agents
In examples/gemm_streamk/example_tilelang_gemm_streamk.py around lines 132 to
144, the guard uses `if sm_patition_factor > 0:` which can be true even when
there are no blocking tiles (leading to invalid tile indexing); replace that
condition with `if blocking_tiles > 0:` so the full-tiles computation only runs
when blocking_tiles actually exist.

Comment on lines +26 to 28
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file structure and find the main_no_split and main_split functions
fd "example_warp_specialize_flashmla.py"

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

# Once we locate the file, let's examine the main_no_split function and understand the parameters
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,40p'

Repository: tile-ai/tilelang

Length of output: 1294


🏁 Script executed:

# Search for main_split function to see if a split variant exists
rg -n "def main_split|@T.prim_func" examples/warp_specialize/example_warp_specialize_flashmla.py | head -20

Repository: tile-ai/tilelang

Length of output: 79


🏁 Script executed:

# Search for usages of glse and Output_partial within the file
rg -n "glse|Output_partial" examples/warp_specialize/example_warp_specialize_flashmla.py | head -40

Repository: tile-ai/tilelang

Length of output: 387


🏁 Script executed:

# Let's see the full function body and look for any usage of glse or Output_partial in the main_no_split function
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,300p' | grep -E "(glse|Output_partial|def main_split|def ref_program)"

Repository: tile-ai/tilelang

Length of output: 269


🏁 Script executed:

# Let's check if there's a main_split function by searching more broadly
rg -n "def main_split" examples/warp_specialize/example_warp_specialize_flashmla.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Let's look at the ref_program function to understand what it does with these parameters
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '299,330p'

Repository: tile-ai/tilelang

Length of output: 1763


🏁 Script executed:

# Let's check the file size and look for any calls to main_no_split with arguments
wc -l examples/warp_specialize/example_warp_specialize_flashmla.py

Repository: tile-ai/tilelang

Length of output: 124


🏁 Script executed:

# Let's look at the ref_program function completely to see if glse and Output_partial are actually used
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '299,362p'

Repository: tile-ai/tilelang

Length of output: 3534


🏁 Script executed:

# Let's check how main_no_split is actually called/tested
rg -n "main_no_split" examples/warp_specialize/example_warp_specialize_flashmla.py

Repository: tile-ai/tilelang

Length of output: 114


🏁 Script executed:

# Let's also check if there's a split version anywhere in the codebase
fd ".*flashmla.*" --type f

Repository: tile-ai/tilelang

Length of output: 167


🏁 Script executed:

# Check the function signature and call context more carefully
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '340,362p'

Repository: tile-ai/tilelang

Length of output: 1437


🏁 Script executed:

# Let's check the flashattn function to understand the architecture better
rg -n "def flashattn" examples/warp_specialize/example_warp_specialize_flashmla.py -A 5

Repository: tile-ai/tilelang

Length of output: 375


🏁 Script executed:

# Let's look at the full main_no_split function body to confirm no usage of glse/Output_partial
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,300p'

Repository: tile-ai/tilelang

Length of output: 15063


Remove or mark the unused glse and Output_partial parameters in the function signature.

These parameters are declared in both main_no_split and ref_program but never referenced in their implementations. Since they add unnecessary interface complexity, either remove them entirely or prefix with underscore (_glse, _Output_partial) to explicitly mark them as intentionally unused.

🧰 Tools
🪛 Ruff (0.14.10)

26-26: Unused function argument: glse

(ARG001)


27-27: Unused function argument: Output_partial

(ARG001)

🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_flashmla.py around lines 26
to 28, the function signatures for main_no_split and ref_program include unused
parameters glse and Output_partial; either remove these parameters from both
function signatures and all call sites, or rename them to _glse and
_Output_partial to mark them as intentionally unused; update any type
annotations and docstrings accordingly to keep signatures consistent across
definitions and uses.

@LeiWang1999
Member

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. You're on a roll.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@Rachmanino
Collaborator Author

waiting to rebase on #1509

LeiWang1999 previously approved these changes Dec 24, 2025
@LeiWang1999
Member

@codex review

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
examples/warp_specialize/example_warp_specialize_flashmla.py (1)

26-27: Unused parameters flagged in previous review.

The glse and Output_partial parameters remain unused as flagged in the previous review comment on lines 26-28. Please address the earlier feedback to either remove these parameters or mark them as intentionally unused with underscore prefixes.

examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)

87-124: Code duplication concern from previous review still applies.

The inlined attention logic in this block appears functionally correct and follows the standard FlashAttention pattern with proper causal masking, sliding window support, and numerical stability. However, as noted in the previous review, this same logic is duplicated across multiple variant files (e.g., example_mha_sink_fwd_bhsd_wgmma_pipelined.py), creating maintenance challenges.

🧹 Nitpick comments (4)
examples/flash_decoding/example_mha_inference.py (1)

158-165: Consider adding a corresponding guard in main or flashattn.

ref_program asserts causal is False, but a caller could invoke main(causal=True) or construct a kernel directly with causal mode. Given the incomplete causal split handling noted in the TODO comments, consider adding an early check or assertion at the kernel level to fail fast with a clear message rather than relying on the reference function's assertion during validation.
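A sketch of such a guard, assuming main() exposes causal and num_split keyword arguments (the exact signature in the example may differ):

def main(causal: bool = False, num_split: int = 1):
    # Fail fast with a clear message instead of tripping ref_program's assert
    # later during validation; causal handling in the split path is still TODO.
    if causal:
        raise NotImplementedError(
            "causal=True is not supported by this example yet; "
            "the reference check and split handling are non-causal only."
        )
    ...  # build the kernel and run validation as before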

examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)

341-343: Consider removing unused tensor allocations.

O_tl and S_tl are allocated at lines 341-342 but immediately overwritten by the kernel call at line 343. These pre-allocations are unnecessary since the kernel returns new tensors based on out_idx.

🔎 Proposed fix
-    O_tl = torch.zeros_like(Q)
-    S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
     O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)

877-880: Hardcoded argument overrides disable CLI flexibility.

These lines override parsed arguments, which defeats the purpose of the argparse configuration. If this is intentional for testing, consider removing these lines or adding a comment explaining the temporary override.

🔎 Proposed fix
-    args.test_sink = True
-    args.test_varlen = False
-    args.dtype = T.float16
-    args.num_split = 1
examples/deepseek_mla/example_mla_decode.py (1)

122-180: Minor inconsistency: Consider aligning with main_split for maintainability

The main_no_split implementation differs slightly from main_split in two ways:

  1. Accumulator clearing (line 159 vs lines 66-67): main_no_split passes clear_accum=True instead of issuing an explicit T.clear(acc_s). Both are correct.

  2. Final GEMM input (line 176 vs line 85): main_no_split feeds S_shared directly instead of copying to acc_s_cast. The main_split version allocates acc_s_cast and copies through it (lines 43, 80), while main_no_split omits this buffer entirely.

While both approaches are functionally correct (both pass float16 to the final GEMM), the inconsistency may impact maintainability. Consider aligning the implementations unless the difference is intentional optimization.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea7da83 and bf5e2d9.

📒 Files selected for processing (8)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/deepseek_mla/example_mla_decode.py
  • examples/deepseek_mla/example_mla_decode_paged.py
  • examples/flash_decoding/example_gqa_decode_varlen_logits.py
  • examples/flash_decoding/example_mha_inference.py
  • examples/warp_specialize/example_warp_specialize_flashmla.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
🧰 Additional context used
🪛 Ruff (0.14.10)
examples/warp_specialize/example_warp_specialize_flashmla.py

26-26: Unused function argument: glse

(ARG001)


27-27: Unused function argument: Output_partial

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
examples/warp_specialize/example_warp_specialize_flashmla.py (2)

20-21: LGTM! Macro replaced with prim_func as intended.

The transition from a macro-based implementation to @T.prim_func aligns with the PR's cleanup objectives to remove unnecessary macros and use explicit primitive function entry points.


347-347: Verify the inconsistent batch size default.

The argparse default for --batch is now 132, but the function main() at line 328 still has batch=1 as its default parameter. This creates an inconsistency between programmatic and CLI usage.

If this is intentional (e.g., 132 is optimized for specific hardware or benchmarking), consider documenting the rationale. Otherwise, align both defaults for consistency.
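One low-effort way to keep the two defaults in sync is to hoist the value into a module-level constant; the snippet below is a generic sketch, not the example's current code:

import argparse

DEFAULT_BATCH = 132  # assumed single source of truth for both entry points

def main(batch: int = DEFAULT_BATCH) -> None:
    print(f"running with batch={batch}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=DEFAULT_BATCH)
    main(parser.parse_args().batch)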

examples/flash_decoding/example_mha_inference.py (4)

20-28: LGTM on the inlined prim_func signature.

The function signature correctly defines all input/output tensors with appropriate shapes for the split/combine flash decoding pattern. The transition from macro-based to inline implementation is clean.


56-67: Verify behavior when seqlen_kv is not evenly divisible by num_split.

The integer division seqlen_kv // num_split on lines 60, 65, and 102 will truncate, potentially skipping the last seqlen_kv % num_split KV elements. If this edge case is expected to be handled by callers (i.e., only passing divisible lengths), consider adding an assertion or documenting the constraint.

Additionally, the TODO at line 56 notes that causal split case handling is incomplete—the causal loop range (line 58) references full seqlen_kv but K/V indexing still uses split portions, which may produce incorrect results for causal attention with splits.
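A small guard along those lines, assuming the divisibility constraint is the intended contract (names are illustrative):

def check_split_divisibility(seqlen_kv: int, num_split: int) -> None:
    # Make the constraint explicit so the truncating // cannot silently drop
    # the trailing seqlen_kv % num_split KV positions.
    if seqlen_kv % num_split != 0:
        raise ValueError(
            f"seqlen_kv ({seqlen_kv}) must be divisible by num_split "
            f"({num_split}); {seqlen_kv % num_split} KV positions would be skipped."
        )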


76-105: LGTM on the inlined softmax and rescale logic.

The online softmax implementation correctly follows the FlashAttention pattern with exp2/log2(e) optimization. The rescaling of acc_o before GEMM and the log-sum tracking are properly implemented.
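For readers less familiar with the pattern, a numpy rendering of the per-tile update being described; buffer names are illustrative, and the base-2 exponentials mirror the kernel's exp2/log2(e) trick:

import numpy as np

LOG2E = 1.44269504  # log2(e); the kernels fold this into exp2-based softmax

def online_softmax_step(acc_o, logsum, scores_max_prev, s_block, v_block, scale):
    # One KV tile of the streaming softmax, written in numpy for clarity.
    # s_block: (M, N) raw QK^T scores for this tile; v_block: (N, D).
    scores_max = np.maximum(scores_max_prev, s_block.max(axis=1))
    # Rescale the running accumulator and logsum for the new row maxima.
    alpha = np.exp2((scores_max_prev - scores_max) * scale * LOG2E)
    p = np.exp2((s_block - scores_max[:, None]) * scale * LOG2E)
    acc_o = acc_o * alpha[:, None] + p @ v_block
    logsum = logsum * alpha + p.sum(axis=1)
    return acc_o, logsum, scores_max  # divide acc_o by logsum after the last tile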


115-153: LGTM on the combine phase.

The combine kernel correctly implements the log-sum-exp reduction across splits, matching the pattern in the reference implementation. The use of lse_max_local for numerical stability in the exp2 computations is appropriate.
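A numpy sketch of that reduction, under the assumption that each split's partial output is already normalized by its local logsum and that the lse values are stored in base 2:

import numpy as np

def combine_splits(output_partial, lse_partial):
    # output_partial: (num_split, M, D) per-split outputs; lse_partial:
    # (num_split, M) per-split log-sum-exp values in base 2.
    lse_max = lse_partial.max(axis=0)              # plays the role of lse_max_local
    w = np.exp2(lse_partial - lse_max[None, :])    # relative weight of each split
    out = (output_partial * w[:, :, None]).sum(axis=0) / w.sum(axis=0)[:, None]
    return out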

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)

96-133: Code correctness verified; duplication concern already raised.

The inlined attention logic (K copy, causal/sliding-window masking, flash attention softmax, V processing) is correct and follows standard patterns. The past review comment on this range already comprehensively addresses the code duplication concern across multiple files. I have no additional issues to raise beyond what has been covered.

examples/flash_decoding/example_gqa_decode_varlen_logits.py (4)

201-202: LGTM!

Removing debug_root_path from the decorator is a good cleanup. The out_idx=[-2, -1] correctly identifies the output tensors (Output and S) based on their positions in the function signature.


219-228: LGTM!

The conversion from macro-based implementation to explicit @T.prim_func is clean. The Output tensor correctly uses shape_o which is defined earlier as [batch, heads, dim], matching the expected output shape for the decode operation.


375-376: Consistent with the tilelang path, but same verification applies.

The contiguity assertions are relaxed in both the tilelang and Triton implementations. While consistency is good, the underlying Triton kernel (_fwd_kernel_varlen) also needs to correctly handle strided K/V tensors. The kernel does pass explicit stride parameters, which suggests strided access is intended.


336-337: The K and V contiguity assertions are safely removed.

The flashattn_gqa_decode_no_split kernel uses strided memory access via T.copy(), which is designed to handle arbitrary tensor layouts. Unlike the paged variant (which requires contiguity for the block table mechanism), the non-paged variant correctly supports non-contiguous K and V tensors.

examples/deepseek_mla/example_mla_decode.py (3)

24-33: LGTM: Clean prim_func signature

The function signature correctly exposes all necessary parameters for the split attention path, including the intermediate outputs (glse, Output_partial) and final output (Output).


34-93: LGTM: Split attention kernel logic is correct

The split attention computation properly implements online softmax with per-split accumulation. The use of acc_s_cast for the final gemm ensures proper dtype handling.


94-120: LGTM: Combine logic correctly merges split results

The log-sum-exp merging logic properly combines partial attention outputs from multiple splits using numerically stable exp2 operations.

examples/deepseek_mla/example_mla_decode_paged.py (4)

27-38: LGTM: Improved parameter naming convention

The function signature correctly uses lowercase snake_case for block_table and cache_seqlens, which is more Pythonic than the previous uppercase macro-style naming.


80-82: LGTM: Correct boundary masking

The masking logic properly handles out-of-bounds elements by checking against cache_seqlens[bx], ensuring invalid positions are set to negative infinity before the softmax operation.
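A plain-numpy equivalent of that masking step, with illustrative names (kv_base standing in for the tile's global KV offset):

import numpy as np

def mask_out_of_bounds(scores, kv_base, cache_seqlen):
    # scores: (M, block_N) QK^T tile; any column whose global KV index
    # (kv_base + j) lies at or beyond cache_seqlen is set to -inf so it
    # vanishes after the softmax. Stand-in for the kernel's T.if_then_else masking.
    j = np.arange(scores.shape[1])
    return np.where(kv_base + j < cache_seqlen, scores, -np.inf)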


202-205: LGTM: Clean dispatch logic

The dispatch correctly selects between split and no-split variants based on the num_split parameter.


169-200: Remove the block_table offset concern—line 172 is correct for the no_split reverse iteration.

The offset calculation at line 172 is correct because the no_split kernel processes all blocks for a given batch element (no start offset like main_split). Since there is no start offset, both the block index (k * block_N) // block_size and the offset (k * block_N) % block_size correctly use only k * block_N.

The reverse iteration (k = loop_range - 1 - kr) is intentional and the conditional masking at kr == 0 correctly handles the last block, which may be partially filled. This pattern is consistent with other paged KV cache implementations in the codebase.

Likely an incorrect or invalid review comment.

Comment on lines +65 to 73
total_blocks = T.ceildiv(cache_seqlens[bx], block_N)
blocks_per_split = T.floordiv(total_blocks, num_split)
remaining_blocks = T.floormod(total_blocks, num_split)
loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N

for k in T.Pipelined(loop_range, num_stages=2):
-   kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
+   kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
    T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
Copy link
Contributor


⚠️ Potential issue | 🔴 Critical

Critical: Incorrect offset calculation in block_table indexing

Line 72 miscomputes the intra-block offset:

kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size

The offset calculation (k * block_N) % block_size ignores the start position, which can lead to incorrect memory access when start is not aligned to block_size.

Example scenario:

  • block_size = 128, block_N = 64
  • start = 192 (split doesn't start at block boundary)
  • k = 0: The code computes offset = (0 * 64) % 128 = 0, but position 192 is actually at offset 64 within its block.
🔎 Proposed fix
-                kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
+                kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size
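A quick pure-Python check of the scenario above, comparing the current and proposed offset formulas:

block_size, block_N, start = 128, 64, 192
for k in range(2):
    pos = start + k * block_N
    buggy = (k * block_N) % block_size        # ignores start
    fixed = pos % block_size                  # accounts for start
    print(f"k={k}: position {pos} -> buggy offset {buggy}, fixed offset {fixed}")
# k=0: position 192 -> buggy offset 0, fixed offset 64
# k=1: position 256 -> buggy offset 64, fixed offset 0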
🤖 Prompt for AI Agents
examples/deepseek_mla/example_mla_decode_paged.py around lines 65 to 73: the
block_table indexing computes the intra-block offset using (k * block_N) %
block_size which ignores the split start and can point to the wrong position;
change the index and offset calculations to use the absolute position = start +
k * block_N, e.g. compute idx = (position) // block_size and offset = (position)
% block_size, then compute kv_start = block_table[bx, idx] * block_size + offset
so the start contribution is included and accesses are correct.

@chatgpt-codex-connector
Copy link

Codex Review: Didn't find any major issues. Delightful!


@LeiWang1999 LeiWang1999 merged commit 42697c0 into tile-ai:main Dec 24, 2025
1 of 3 checks passed