[Cleanup] Remove unnecessary macros in tilelang examples#1514
LeiWang1999 merged 4 commits into tile-ai:main from
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

This PR inlines previously macro-based helpers (MMA0, MMA1, Softmax, Rescale) into many attention kernels' main loops and replaces several macro split/combine flows with explicit prim_func entry points (main_split / main_no_split) that add parameters like glse, Output_partial, and Output. No widespread external API removals; some prim_func signatures were added or adjusted.
Sequence Diagram(s): omitted; changes are dispersed across many example kernels and are primarily macro inlining / prim_func reorganization rather than a new multi-component runtime flow.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ 1 check failed (warning), ✅ 2 checks passed.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/flash_decoding/example_gqa_decode.py (1)
182-186: Potential out-of-bounds access when dim != 128.

scale_local is allocated with shape [128] (line 167), but at line 186 it's indexed by i, where i iterates over dim. If dim > 128, this causes an out-of-bounds access. The existing comment on line 184 acknowledges this coupling but doesn't prevent the issue.

Since scale is uniform across the dimension (a per-head scalar), consider either:
- Reducing scale_local to a scalar and broadcasting explicitly, or
- Allocating scale_local with size [dim] and adjusting the parallel range accordingly.

🔎 Suggested fix (Option 1 - use scalar broadcast)

```diff
-        scale_local = T.alloc_fragment([128], accum_dtype)
+        scale_local = T.alloc_fragment([1], accum_dtype)
         ...
         for k in T.serial(num_split):
             for i in T.Parallel(dim):
                 po_local[i] = Output_partial[bz, by, k, i]
-            for j in T.Parallel(128):
-                scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
+            scale_local[0] = T.exp2(lse_local[k, 0] - lse_logsum_local[0])
             # Note: Pay attention to dim and the number of threads in Parallel
             for i in T.Parallel(dim):
-                o_accum_local[i] += po_local[i] * scale_local[i]
+                o_accum_local[i] += po_local[i] * scale_local[0]
```
🧹 Nitpick comments (9)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
246-254: Consider removing empty annotation block.

The T.annotate_layout({}) block with all layouts commented out provides no functionality. This could be cleaned up in a follow-up, though it's not breaking anything.

🔎 Proposed cleanup

```diff
-            T.annotate_layout(
-                {
-                    # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
-                    # K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-                    # V_shared: tilelang.layout.make_swizzled_layout(V_shared),
-                    # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
-                    # S_shared: tilelang.layout.make_swizzled_layout(S_shared),
-                }
-            )
```

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
72-102: Consider abstracting shared logic between layout variants.

The inline flash attention logic is nearly identical to example_mha_fwd_bshd_wgmma_pipelined.py, differing only in tensor layout indexing. While this duplication may be intentional for example clarity, you could reduce maintenance burden by abstracting the common kernel logic.

examples/seer_attention/block_sparse_attn_tilelang.py (1)
93-117: The inlined Softmax and Rescale logic is correct. The FlashAttention algorithm is implemented correctly:
- Max score tracking across blocks (lines 93-97)
- Rescaling factor computation using exp2 (lines 103-104)
- Softmax with log2(e) optimization (lines 105-109)
- Running logsum update (lines 111-112)
- Accumulated output rescaling (lines 115-116)
However, consider these optional refinements:
1. Simplify initialization (lines 94-95):
The explicit T.fill on line 94 followed by reduce_max with clear=False on line 95 is redundant. Since reduce_max with clear=True (the default) automatically initializes the output buffer to -infinity, you can simplify:

🔎 Optional simplification

```diff
-T.copy(scores_max, scores_max_prev)
-T.fill(scores_max, -T.infinity(accum_dtype))
-T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+T.copy(scores_max, scores_max_prev)
+T.reduce_max(acc_s, scores_max, dim=1, clear=True)
```

Both approaches are functionally equivalent, but the simplified version is more idiomatic.
2. Commented code (lines 101-102):
The commented-out Check_inf logic includes a clear explanation referencing FlashAttention3. If this check is not needed for correctness (as evidenced by passing tests), consider removing the commented code to reduce clutter. If it may be needed for future edge cases, consider adding a more prominent TODO or issue reference.

examples/flash_attention/example_mha_fwd_bshd.py (1)
59-88: Optional: Consider adding phase comments for readability. The inlined logic is correct and self-contained. For improved readability, consider adding brief comments to delineate the major phases within the pipelined loop:
- K copy and masking (lines 60-67)
- Online softmax normalization (lines 69-82)
- Output rescaling (lines 83-85)
- V GEMM (lines 86-88)
This would help readers quickly navigate the attention computation flow.
examples/flash_decoding/example_gqa_decode.py (1)
200-200: Unused variable bz should be prefixed with underscore.

The kernel dimension bz is unpacked but never used in the no-split path (since num_split == 1). Prefix it with an underscore to indicate it's intentionally unused.

🔎 Suggested fix

```diff
-        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
+        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, _bz):
```

examples/flash_decoding/example_mha_inference.py (3)
84-85: Consider removing or conditionally enabling commented Check_inf code. The commented code addresses -inf handling for causal softmax. Either remove it if not needed, or implement it with the appropriate condition mentioned in the comment ("only need to be done in the first ceil_div(kBlockM, kBlockN) steps").
64-67: Extract split offset calculation for readability.

The repeated expression (seqlen_kv // num_split) * sid in K and V indexing could be extracted to improve clarity.

🔎 Suggested refactor

Before the loop (after line 62):

```python
kv_split_offset = (seqlen_kv // num_split) * sid
```

Then simplify lines 65-66 and 102-103:

```diff
-                K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :],
+                K[bid, kv_split_offset + k * block_N : kv_split_offset + (k + 1) * block_N, hid, :],
```

```diff
-                V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :],
+                V[bid, kv_split_offset + k * block_N : kv_split_offset + (k + 1) * block_N, hid, :],
```

Also applies to: 101-104
167-174: Document unused parameters in ref_program.

The glse and Output_partial parameters are unused in the function body but appear required for signature matching with the profiler system. Consider adding a comment explaining this, or prefix them with underscore (_glse, _Output_partial) to indicate they're intentionally unused.

examples/gemm_streamk/example_tilelang_gemm_streamk.py (1)
93-130: LGTM! The first wave computation inlining looks correct. The macro-to-inline transformation preserves the original stream-k semantics correctly, with proper handling of partial tiles via atomic operations.
Optional: Eliminate redundant subexpression calculation
The expression start_iter[0] % iters_per_tile is computed three times (lines 107, 114, 118). Consider reusing the remain_iters variable to improve readability:

```diff
 tile_id = start_iter[0] // iters_per_tile
 remain_iters = start_iter[0] % iters_per_tile
 pid_m = tile_id // T.ceildiv(N, block_N)
 pid_n = tile_id % T.ceildiv(N, block_N)
 T.clear(C_local)
 for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
     T.copy(
-        A[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
+        A[pid_m * block_M, (k + remain_iters) * block_K],
         A_shared,
     )
     T.copy(
-        B[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
+        B[pid_n * block_N, (k + remain_iters) * block_K],
         B_shared,
     )
     T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (24)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
- examples/attention_sink/example_mha_sink_fwd_bhsd.py
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
- examples/deepseek_mla/example_mla_decode.py
- examples/deepseek_mla/example_mla_decode_paged.py
- examples/deepseek_mla/example_mla_decode_ws.py
- examples/flash_attention/example_gqa_fwd_bshd.py
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
- examples/flash_attention/example_mha_fwd_bhsd.py
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
- examples/flash_attention/example_mha_fwd_bshd.py
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
- examples/flash_decoding/example_gqa_decode.py
- examples/flash_decoding/example_gqa_decode_varlen_logits.py
- examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
- examples/flash_decoding/example_mha_inference.py
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- examples/seer_attention/block_sparse_attn_tilelang.py
- examples/warp_specialize/example_warp_specialize_flashmla.py
🧰 Additional context used
🧬 Code graph analysis (16)
examples/seer_attention/block_sparse_attn_tilelang.py (4)
- src/tl_templates/cuda/reduce.h (1): T (178-250)
- tilelang/language/copy_op.py (1): copy (14-95)
- tilelang/language/fill_op.py (2): fill (9-36), clear (39-62)
- tilelang/language/reduce_op.py (2): reduce_max (107-125), reduce_sum (144-166)

examples/gemm_streamk/example_tilelang_gemm_streamk.py (3)
- tilelang/language/loop.py (2): Pipelined (58-95), Parallel (13-33)
- tilelang/language/copy_op.py (1): copy (14-95)
- examples/gemm/example_gemm.py (1): gemm (8-24)

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
- src/op/gemm.h (4): GemmWarpPolicy (59-83), GemmWarpPolicy (64-68), GemmWarpPolicy (70-74), GemmWarpPolicy (76-82)

examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
- examples/flash_decoding/example_gqa_decode.py (1): flashattn_gqa_decode_no_split (191-254)
- examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (1): flashattn_gqa_decode_no_split (59-142)

examples/flash_attention/example_mha_fwd_bshd.py (5)
- tilelang/language/copy_op.py (1): copy (14-95)
- tilelang/language/loop.py (1): Parallel (13-33)
- examples/gemm/example_gemm.py (1): gemm (8-24)
- tilelang/language/fill_op.py (2): fill (9-36), clear (39-62)
- tilelang/language/reduce_op.py (2): reduce_max (107-125), reduce_sum (144-166)

examples/flash_attention/example_gqa_fwd_bshd.py (3)
- tilelang/language/copy_op.py (1): copy (14-95)
- tilelang/language/loop.py (1): Parallel (13-33)
- tilelang/language/reduce_op.py (2): reduce_max (107-125), reduce_sum (144-166)

examples/warp_specialize/example_warp_specialize_flashmla.py (4)
- tilelang/jit/adapter/wrapper.py (2): prim_func (531-541), prim_func (788-798)
- tilelang/jit/adapter/ctypes/adapter.py (1): prim_func (272-274)
- tilelang/jit/adapter/cython/adapter.py (1): prim_func (356-358)
- tilelang/language/tir/entry.py (1): prim_func (10-58)

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (3)
- examples/deepseek_mla/example_mla_decode.py (2): main_split (25-126), main_no_split (129-191)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1): main_split (138-181)
- examples/warp_specialize/example_warp_specialize_flashmla.py (1): main_no_split (21-294)

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (3)
- tilelang/language/copy_op.py (1): copy (14-95)
- tilelang/language/loop.py (1): Parallel (13-33)
- examples/gemm/example_gemm.py (1): gemm (8-24)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2): main (35-155), main (325-398)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2): main (36-140), main (313-380)

examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)
- examples/flash_decoding/example_gqa_decode.py (1): flashattn_gqa_decode_no_split (191-254)
- examples/flash_decoding/example_gqa_decode_varlen_logits.py (1): flashattn_gqa_decode_no_split (220-317)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)
- tilelang/jit/adapter/wrapper.py (2): prim_func (531-541), prim_func (788-798)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2): main (43-165), main (331-512)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2): main (36-140), main (313-380)

examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
- tilelang/tileop/gemm/gemm_base.py (2): K (43-44), policy (120-121)
- src/op/gemm.h (4): GemmWarpPolicy (59-83), GemmWarpPolicy (64-68), GemmWarpPolicy (70-74), GemmWarpPolicy (76-82)

examples/deepseek_mla/example_mla_decode.py (3)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1): main_split (46-136)
- examples/deepseek_mla/example_mla_decode_paged.py (1): main_split (28-137)
- examples/deepseek_mla/example_mla_decode_ws.py (1): main_split (37-280)

examples/flash_decoding/example_mha_inference.py (4)
- tilelang/jit/adapter/ctypes/adapter.py (1): prim_func (272-274)
- tilelang/language/copy_op.py (1): copy (14-95)
- examples/gemm/example_gemm.py (1): gemm (8-24)
- src/op/gemm.h (4): GemmWarpPolicy (59-83), GemmWarpPolicy (64-68), GemmWarpPolicy (70-74), GemmWarpPolicy (76-82)

examples/deepseek_mla/example_mla_decode_ws.py (2)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (2): main_split (46-136), main_no_split (139-197)
- examples/deepseek_mla/example_mla_decode.py (2): main_split (25-126), main_no_split (129-191)
🪛 Ruff (0.14.10)
examples/warp_specialize/example_warp_specialize_flashmla.py
26-26: Unused function argument: glse
(ARG001)
27-27: Unused function argument: Output_partial
(ARG001)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
144-144: Unused function argument: glse
(ARG001)
145-145: Unused function argument: Output_partial
(ARG001)
examples/deepseek_mla/example_mla_decode_paged.py
147-147: Unused function argument: glse
(ARG001)
148-148: Unused function argument: Output_partial
(ARG001)
examples/flash_decoding/example_gqa_decode.py
200-200: Unpacked variable bz is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/deepseek_mla/example_mla_decode_ws.py
288-288: Unused function argument: glse
(ARG001)
289-289: Unused function argument: Output_partial
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (43)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (2)
26-26: LGTM on JIT decorator cleanup. Removal of debug_root_path simplifies the decorator usage appropriately for production examples.

58-68: LGTM on prim_func conversion. The conversion from macro to @T.prim_func is consistent with the broader PR pattern. The function signature correctly includes BLOCK_TABLE for paged KV cache support, and the output tensor uses the dynamic shape_o definition.

examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
201-202: LGTM on decorator stack. The @autotune and @tilelang.jit(out_idx=[-2, -1]) decorators are correctly stacked. The out_idx=[-2, -1] correctly identifies Output and S as the output tensors.

219-228: LGTM on prim_func conversion. The conversion from macro to @T.prim_func is correctly implemented. The function signature properly declares the 7 parameters with Output and S as outputs (matching out_idx=[-2, -1]). The caller at line 353 correctly passes only the 5 input arguments since the JIT decorator handles output tensor allocation.

examples/flash_attention/example_mha_fwd_bhsd.py (3)
66-76: LGTM: K loading, masking, and attention score computation correctly implemented. The inlined operations correctly implement:
- Block-wise K loading into shared memory
- Causal masking with proper query/key index calculation using the past_len offset (see the sketch after this list)
- Padding-aware masking for non-causal paths
- Accumulation of attention scores via GEMM (Q @ K^T) on top of mask values
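As referenced above, here is a minimal NumPy sketch of the causal/padding masking idea with a past_len offset. It is purely illustrative: the tile sizes, block indices, and offset below are made-up values, not taken from the example files.

```python
import numpy as np

block_M, block_N = 4, 4      # hypothetical tile sizes, for illustration only
bx, k, past_len = 0, 0, 2    # hypothetical query-block index, key-block index, KV-cache offset

# Initialize the score tile to 0 where attention is allowed and -inf where masked,
# mirroring the T.if_then_else(q_idx >= k_idx, 0, -inf) pattern; the kernel then
# accumulates Q @ K^T on top of this tile, so masked entries stay at -inf.
acc_s = np.full((block_M, block_N), -np.inf)
for i in range(block_M):
    for j in range(block_N):
        q_idx = bx * block_M + i + past_len   # absolute query position, shifted by past_len
        k_idx = k * block_N + j               # absolute key position
        if q_idx >= k_idx:                    # causal: only keys at or before the query are visible
            acc_s[i, j] = 0.0

print(acc_s)
```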
77-89: LGTM: Online softmax algorithm correctly implements numerically stable attention weights. The inlined softmax computation properly maintains:
- Running maximum tracking across blocks (lines 77-81)
- Exponential rescaling of previous accumulations via scores_scale (lines 82-83)
- Numerically stable exponentiation with max subtraction and log2(e) scaling (lines 84-85)
- Running sum updates with proper scale factors (lines 86-88)
This is the standard flash attention online softmax pattern for memory-efficient, numerically stable attention.
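For readers who want to see what this pattern computes end to end, here is a toy 1-D NumPy model of the online softmax recurrence described above. It is a sketch of the algorithm only, not the tilelang kernel; the block size, seed, and function name are arbitrary.

```python
import numpy as np

def online_attention_1d(scores, values, block=4):
    """Streaming softmax-weighted sum over `values`, processed block by block."""
    running_max = -np.inf   # plays the role of scores_max
    denom = 0.0             # running softmax denominator (logsum)
    acc = 0.0               # running un-normalized weighted sum (acc_o before the final divide)
    for start in range(0, len(scores), block):
        s = scores[start:start + block]
        v = values[start:start + block]
        new_max = max(running_max, s.max())
        scale = np.exp(running_max - new_max)  # rescales everything accumulated under the old max
        p = np.exp(s - new_max)                # this block's unnormalized probabilities
        denom = denom * scale + p.sum()
        acc = acc * scale + p @ v
        running_max = new_max
    return acc / denom                         # single normalization after the loop

rng = np.random.default_rng(0)
scores, values = rng.normal(size=16), rng.normal(size=16)
p = np.exp(scores - scores.max())
print(np.isclose(online_attention_1d(scores, values), (p / p.sum()) @ values))  # True
```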
91-96: LGTM: Output rescaling and value accumulation correctly implement the flash attention pattern. The inlined operations properly:
- Rescale previous output accumulations by scores_scale to account for updated max values (lines 91-92)
- Load the corresponding V block into shared memory (line 94)
- Accumulate new contributions via GEMM (attention_weights @ V) into acc_o (line 95)

This correctly completes the online attention computation before final normalization.
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (3)
88-95: LGTM! K loading and masking logic is correct. The inlined operations correctly implement:
- Grouped query attention indexing (by // groups)
- Causal masking (future tokens masked with -inf when query position < key position)
- Sequence boundary checks for non-causal attention
97-109: LGTM! Online softmax implementation is correct. The inlined softmax operations correctly implement the numerically stable online algorithm with:
- Running max tracking across blocks
- Proper rescaling of previous accumulations
- Log2-space computations (using the log2(e) scale factor from line 44)
111-115: LGTM! Rescaling and final GEMM operations are correct. The inlined operations properly:
- Rescale accumulated outputs before adding new block contributions
- Use consistent group indexing for V (matching K's indexing)
- Perform the final attention-weighted value accumulation
The macro inlining appears to be a clean refactoring that improves code readability by making the computation steps explicit. Since the pipeline schedule parameters (order, stage, group) remain unchanged, you may want to verify that the pipelined execution still performs as expected after this refactoring.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (4)
72-81: LGTM! K-side GEMM with proper causal masking. The inline logic correctly handles:
- K block loading into shared memory
- Causal masking with past_len offset for decoding scenarios
- Out-of-bounds masking for non-causal paths
- Q @ K^T computation via transposed GEMM
83-95: LGTM! Numerically stable online softmax. The inline softmax correctly implements the incremental max-sum tracking pattern:
- Maintains running max across blocks
- Computes per-block exponentials with numerically stable scaling
- Updates running logsum with previous block rescaling
- Uses exp2 with pre-scaled factors for efficiency

97-98: LGTM! Correct accumulator rescaling. The per-row rescaling of acc_o by scores_scale correctly adjusts previous contributions when the running max changes.

100-101: LGTM! V-side GEMM completes the attention computation. The V block loading and final GEMM correctly compute the attention-weighted values, accumulating across blocks into acc_o.

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (4)
66-73: LGTM! K-side GEMM with causal masking for equal-length sequences. The inline logic correctly handles the BSHD layout with equal query/key sequence lengths:
- K block loading with BSHD indexing
- Causal masking without past_len offset (appropriate when query and key sequence lengths are equal)
- Out-of-bounds masking for the non-causal path
- Q @ K^T computation
75-87: LGTM! Numerically stable online softmax. Identical to the BHSD variant: correctly implements incremental max-sum tracking with rescaling.

89-90: LGTM! Correct accumulator rescaling. Identical to the BHSD variant: correctly rescales acc_o before adding new V contributions.

92-93: LGTM! V-side GEMM completes the attention computation. Correctly loads V with BSHD layout indexing and performs the final attention-weighted GEMM.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
34-47: LGTM - Clean refactoring from macro to prim_func.The consolidation of
flash_attn_splitmacro into themainprim_func is well-structured. The inline comments at Lines 46-47 help document the origin of the inlined logic.
122-155: Combine kernel correctly integrated.The combine stage properly aggregates
Output_partialacross splits using the log-sum-exp scaling pattern, producing the finalOutput.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)
42-53: LGTM - Consistent refactoring for paged attention variant.The consolidation follows the same pattern as other files, with the paged attention-specific
block_tableparameter properly integrated into the kernel logic.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
35-44: LGTM - Macro removal consistent with other sparse attention variants.The transition from macro-based to inline prim_func follows the established pattern, maintaining the block_mask-based sparse attention logic.
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (1)
85-117: LGTM - Flash attention logic correctly inlined.The previously macro-based operations (MMA, Softmax, Rescale) are cleanly inlined into the main loop. The implementation correctly:
- Handles causal masking (Lines 86-90)
- Computes online softmax with running max and log-sum-exp tracking
- Applies rescaling to maintain numerical stability
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
144-146: Unused parameters are intentional for API consistency.The static analysis correctly identifies that
glseandOutput_partialare unused inmain_no_split. This is expected - these parameters exist to maintain a consistent function signature withmain_split, allowing the caller to use the same interface regardless of which path is selected at runtime.examples/deepseek_mla/example_mla_decode.py (2)
24-32: LGTM - Clean split/no-split separation.The
main_splitfunction correctly implements the split-path execution with per-splitglseandOutput_partialhandling, followed by a combine stage.
128-196: LGTM - Non-split path correctly writes directly to Output.The
main_no_splitfunction properly bypasses the partial output mechanism and writes the final result directly toOutput, which is the expected behavior whennum_split == 1.examples/deepseek_mla/example_mla_decode_paged.py (2)
27-38: LGTM - Paged attention variant with proper block table handling.The
main_splitcorrectly incorporates block table lookups for paged KV cache access while maintaining the split-path execution pattern.
147-149: Unused parameters are intentional for API consistency.Similar to the AMD benchmark file,
glseandOutput_partialare unused inmain_no_splitto maintain signature compatibility withmain_split. This enables the conditional return at Lines 214-217 to work with a uniform interface.examples/deepseek_mla/example_mla_decode_ws.py (3)
36-47: LGTM - Warp-specialized split path correctly structured.The
main_splitfunction properly implements the warp-specialized execution pattern with:
- Thread range 0-127: QK GEMM and left-half output accumulation
- Thread range 128-255: Right-half output accumulation
- Thread range 256+: Producer (KV data loading)
The split-aware partial output writes at Lines 177-178 and 208 are correct.
254-280: Combine stage correctly aggregates split outputs. The combine kernel uses the standard log-sum-exp scaling pattern to merge Output_partial across splits into the final Output.
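For intuition, the following is a small NumPy sketch of that log-sum-exp merge. It is a simplified stand-in for the glse / Output_partial combine (per split, 1-D, natural-log based), not the actual kernel, which works per (batch, head) and uses exp2 with a max shift; the algebra is the same.

```python
import numpy as np

def combine_splits(partial_out, partial_lse):
    """Merge per-split outputs using each split's log-sum-exp statistic.

    partial_out: (num_split, dim) outputs, each normalized by its *local* softmax denominator.
    partial_lse: (num_split,) log of each split's local softmax denominator.
    """
    lse_max = partial_lse.max()    # subtract the max for numerical stability
    w = np.exp(partial_lse - lse_max)
    w /= w.sum()                   # w[s] = Z_s / Z: each split's share of the global denominator
    return w @ partial_out         # weighted sum == globally normalized output

# Sanity check against a single-pass softmax over all keys.
rng = np.random.default_rng(1)
scores = rng.normal(size=(2, 8))    # 2 splits of 8 keys each
values = rng.normal(size=(2, 8, 4)) # dim = 4
local_p = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
partial = np.einsum("sk,skd->sd", local_p, values)
lse = np.log(np.exp(scores).sum(axis=1))
full_p = np.exp(scores.reshape(-1)) / np.exp(scores.reshape(-1)).sum()
print(np.allclose(combine_splits(partial, lse), full_p @ values.reshape(-1, 4)))  # True
```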
288-290: Unused parameters are intentional for API consistency.As with other files,
glseandOutput_partialare unused inmain_no_splitto maintain a uniform function signature withmain_split.examples/warp_specialize/example_warp_specialize_flashmla.py (1)
354-354: Verify the rationale for changing the default batch size from 1 to 132. The default batch size was changed from 1 to 132, which is a significant increase (132×). The value 132 is unusual (not a power of 2 or typical batch size), which raises questions about whether this change is:
- Intentional for specific benchmarking requirements
- Related to a particular hardware configuration or test scenario
- Accidentally introduced
A batch size of 1 is more typical for examples and quick testing, while 132 significantly increases memory usage and execution time.
Please clarify the rationale for this change. If intentional, consider adding a comment explaining why this specific value is used.
examples/flash_attention/example_gqa_fwd_bshd.py (4)
107-114: Masking and GEMM logic looks correct. The inlined masking correctly handles both causal and non-causal attention:
- Causal masking properly masks future tokens (query < key positions)
- Non-causal masking correctly handles out-of-bounds positions
- GQA indexing with by // groups appropriately maps query heads to KV heads
116-128: Online softmax normalization is correctly implemented. The inlined sequence properly implements the online softmax algorithm with correct order of operations:
- Max propagation and combination with previous iterations
- Appropriate scaling factors computed in log2 domain
- Running logsum correctly accumulated with per-iteration scaling
130-134: Accumulator scaling and final GEMM are correct. The acc_o scaling by scores_scale before the GEMM correctly implements the online softmax rescaling, and the GQA indexing is consistent for both K and V tensors.
106-135: Verify that the inlined implementation matches the original macro behavior. The inlined attention computation logic is correct and properly implements:
- GQA with appropriate head indexing
- Online softmax with numerical stability
- Causal and non-causal masking
However, ensure that tests pass and the behavior is identical to the original macro-based implementation, particularly for edge cases like:
- Sequence lengths that don't align with block boundaries
- Different group configurations
- Causal vs non-causal modes
If you'd like, I can generate a verification script to check for any remaining references to the removed macros (MMA0, MMA1, Softmax, Rescale) in this file or related test files to ensure the cleanup is complete.
examples/flash_attention/example_mha_fwd_bshd.py (4)
60-67: LGTM! Masking and GEMM logic is correct. The inlined K copy, causal masking, and attention score computation are correctly implemented:
- Causal masking properly masks future tokens (query_pos < key_pos)
- Non-causal masking correctly handles out-of-bounds positions
- The GEMM accumulates into pre-initialized mask values as expected
69-82: LGTM! Online softmax implementation is correct. The inlined softmax normalization logic correctly implements the online/incremental softmax algorithm:
- Properly tracks and updates running maximum (scores_max)
- Correctly computes scale factors for previous accumulations
- Uses exp2 with log2(e) conversion for performance optimization (see the identity check after this list)
- Maintains running sum (logsum) with appropriate rescaling
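As referenced in the list above, the exp2/log2(e) optimization relies on the identity e^x = 2^(x · log2(e)). A quick, purely illustrative check; the kernels typically fold log2(e) into their scale constant together with the softmax scaling.

```python
import math

x, m = 1.7, 2.3                        # arbitrary score and running max, for illustration only
log2e = 1.4426950408889634             # log2(e)
lhs = math.exp(x - m)                  # exp(score - max)
rhs = 2.0 ** (x * log2e - m * log2e)   # exp2(score * log2(e) - max * log2(e)): one FFMA + exp2 on GPU
print(math.isclose(lhs, rhs))          # True
```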
83-85: LGTM! Rescale operation is correct. The rescaling of acc_o by scores_scale is a necessary step in the online softmax algorithm to account for updated maximum scores.
86-88: LGTM! Value GEMM is correct. The V copy and attention-weighted GEMM operation correctly compute the attention output by accumulating across all key-value blocks.
examples/flash_decoding/example_gqa_decode.py (1)
256-259: LGTM! Clean runtime selection between split and no-split paths based on num_split. This properly handles the two execution modes.
20-27: LGTM: Clean prim_func signature.The conversion from macro-based to prim_func with explicit intermediate tensors (glse, Output_partial) is well-structured for the split-combine pattern.
56-73: Causal split masking incomplete but properly documented and protected.The causal mask on line 71 does not account for the split offset and would be incorrect when
num_split > 1andcausal=True. However, this limitation is already explicitly documented via the TODO comments (lines 56, 68) and properly protected:ref_program(line 168) assertscausal is False, and tests default tocausal=False, preventing incorrect behavior. The incomplete feature is intentionally deferred and does not pose a risk in the current state.
```python
T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
```
🛠️ Refactor suggestion | 🟠 Major
GQA variant duplicates logic with minor indexing differences.
This GQA implementation differs from the MHA versions only in the K/V indexing (by // groups on lines 112, 147). The core attention computation is identical. This is a strong signal that a parameterized helper function could serve all variants (MHA, GQA, with/without pipelining) while eliminating duplication.
A unified helper could accept an indexing callback or parameter to handle the MHA vs GQA difference, reducing the maintenance burden across all 27 affected files.
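To make the by // groups indexing concrete, here is a tiny sketch of the query-head to KV-head mapping that GQA uses; the head counts are arbitrary illustrative values.

```python
q_heads, kv_heads = 32, 8
groups = q_heads // kv_heads   # number of query heads sharing one KV head (4 here)

# The kernel's block index `by` ranges over query heads; K/V are indexed with by // groups.
kv_head_of = {by: by // groups for by in range(q_heads)}
print(kv_head_of[0], kv_head_of[3], kv_head_of[4], kv_head_of[31])   # 0 0 1 7
```

With groups == 1 this reduces to the MHA indexing, which is why a single parameterized helper could plausibly serve both variants.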
```python
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
```
🛠️ Refactor suggestion | 🟠 Major
Consider consolidating duplicated attention logic across files.
The inlined attention computation (K copy, masking, GEMM, softmax operations, V processing) is now duplicated across multiple files in this PR. According to the AI summary, 27 files receive similar inlined logic. This creates maintenance challenges:
- Bug fixes require changes in 27 locations
- Performance optimizations must be replicated everywhere
- Risk of inconsistencies if updates are missed
The PR removes macros but doesn't clarify why they were "unnecessary." Macros or helper functions are specifically designed to avoid this kind of duplication.
Could you clarify the rationale for removing the macro-based approach? If the goal is to improve readability or performance at specific call sites, consider:
- Keeping a shared helper function with inline hints
- Using a template/code generation approach during build
- Documenting the specific issues that macros were causing
This would preserve the benefits of the refactor while maintaining DRY principles.
```python
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
    q_idx = bx * block_M + i + past_len
    k_idx = k * block_N + j
    if window_size is not None:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
    else:
        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
    if window_size is not None:
        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
```
🛠️ Refactor suggestion | 🟠 Major
Same code duplication concern applies here.
This file contains the same inlined attention logic as example_mha_sink_fwd_bhsd_wgmma_pipelined.py. The duplication makes it difficult to maintain consistency across variants. Consider extracting the shared logic into a common helper.
```python
# compute full tiles
if sm_patition_factor > 0:
    compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)
    for p in T.serial(sm_patition_factor):
        tile_id = pid + streamk_tiles + p * total_sm
        pid_m = tile_id // T.ceildiv(N, block_N)
        pid_n = tile_id % T.ceildiv(N, block_N)
        T.clear(C_local)

        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            T.copy(A[pid_m * block_M, k * block_K], A_shared_full_tiles)
            T.copy(B[pid_n * block_N, k * block_K], B_shared_full_tiles)
            T.gemm(A_shared_full_tiles, B_shared_full_tiles, C_local, transpose_B=trans_B)
        T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
```
🧩 Analysis chain
🏁 Script executed:
cat -n examples/gemm_streamk/example_tilelang_gemm_streamk.py | head -160
Repository: tile-ai/tilelang
Length of output: 6420
🏁 Script executed:
#!/bin/bash
# Extract the variable definitions to verify the boundary condition
sed -n '15,54p' examples/gemm_streamk/example_tilelang_gemm_streamk.py | cat -n
echo "---"
echo "Analyzing boundary conditions:"
python3 << 'EOF'
# Simulate the values
streamk_programs = 108
# Test case 1: Example from file
total_tiles_1 = (256 // 16) * (1024 // 128) # 16 * 8 = 128
streamk_tiles_1 = total_tiles_1 % streamk_programs # 128 % 108 = 20
if (total_tiles_1 - streamk_tiles_1) > streamk_programs:
streamk_tiles_1 += streamk_programs
blocking_tiles_1 = total_tiles_1 - streamk_tiles_1
sm_partition_factor_1 = max(blocking_tiles_1 // streamk_programs, 1)
print(f"Case 1 (m=256, n=1024, k=512):")
print(f" total_tiles={total_tiles_1}, streamk_tiles={streamk_tiles_1}, blocking_tiles={blocking_tiles_1}")
print(f" sm_partition_factor={sm_partition_factor_1}")
print()
# Test case 2: When total_tiles < streamk_programs (edge case)
total_tiles_2 = 50
streamk_tiles_2 = total_tiles_2 % streamk_programs # 50 % 108 = 50
if (total_tiles_2 - streamk_tiles_2) > streamk_programs:
streamk_tiles_2 += streamk_programs
blocking_tiles_2 = total_tiles_2 - streamk_tiles_2
sm_partition_factor_2 = max(blocking_tiles_2 // streamk_programs, 1)
print(f"Case 2 (total_tiles=50, streamk_programs=108):")
print(f" total_tiles={total_tiles_2}, streamk_tiles={streamk_tiles_2}, blocking_tiles={blocking_tiles_2}")
print(f" sm_partition_factor={sm_partition_factor_2}")
print(f" ISSUE: sm_partition_factor={sm_partition_factor_2} > 0 is True, but blocking_tiles={blocking_tiles_2}")
print()
# Test case 3: When total_tiles is close to streamk_programs
total_tiles_3 = 100
streamk_tiles_3 = total_tiles_3 % streamk_programs # 100 % 108 = 100
if (total_tiles_3 - streamk_tiles_3) > streamk_programs:
streamk_tiles_3 += streamk_programs
blocking_tiles_3 = total_tiles_3 - streamk_tiles_3
sm_partition_factor_3 = max(blocking_tiles_3 // streamk_programs, 1)
print(f"Case 3 (total_tiles=100, streamk_programs=108):")
print(f" total_tiles={total_tiles_3}, streamk_tiles={streamk_tiles_3}, blocking_tiles={blocking_tiles_3}")
print(f" sm_partition_factor={sm_partition_factor_3}")
EOF
Repository: tile-ai/tilelang
Length of output: 146
🏁 Script executed:
python3 << 'EOF'
# Simulate the boundary condition logic
streamk_programs = 108
# Test case 1: Example from file
total_tiles_1 = (256 // 16) * (1024 // 128) # 16 * 8 = 128
streamk_tiles_1 = total_tiles_1 % streamk_programs # 128 % 108 = 20
if (total_tiles_1 - streamk_tiles_1) > streamk_programs:
streamk_tiles_1 += streamk_programs
blocking_tiles_1 = total_tiles_1 - streamk_tiles_1
sm_partition_factor_1 = max(blocking_tiles_1 // streamk_programs, 1)
print(f"Case 1 (m=256, n=1024, k=512):")
print(f" total_tiles={total_tiles_1}, streamk_tiles={streamk_tiles_1}, blocking_tiles={blocking_tiles_1}")
print(f" sm_partition_factor={sm_partition_factor_1}")
print(f" Condition 'sm_partition_factor > 0' is {sm_partition_factor_1 > 0}")
print()
# Test case 2: When total_tiles < streamk_programs
total_tiles_2 = 50
streamk_tiles_2 = total_tiles_2 % streamk_programs # 50 % 108 = 50
if (total_tiles_2 - streamk_tiles_2) > streamk_programs:
streamk_tiles_2 += streamk_programs
blocking_tiles_2 = total_tiles_2 - streamk_tiles_2
sm_partition_factor_2 = max(blocking_tiles_2 // streamk_programs, 1)
print(f"Case 2 (total_tiles=50):")
print(f" blocking_tiles={blocking_tiles_2}, sm_partition_factor={sm_partition_factor_2}")
print(f" ISSUE: Condition is True but blocking_tiles=0!")
print()
# Test case 3: When total_tiles equals streamk_programs
total_tiles_3 = 108
streamk_tiles_3 = total_tiles_3 % streamk_programs # 108 % 108 = 0
if (total_tiles_3 - streamk_tiles_3) > streamk_programs:
streamk_tiles_3 += streamk_programs
blocking_tiles_3 = total_tiles_3 - streamk_tiles_3
sm_partition_factor_3 = max(blocking_tiles_3 // streamk_programs, 1)
print(f"Case 3 (total_tiles=108):")
print(f" blocking_tiles={blocking_tiles_3}, sm_partition_factor={sm_partition_factor_3}")
print(f" Condition is True but blocking_tiles=0!")
EOF
Repository: tile-ai/tilelang
Length of output: 446
Fix guard condition for full tiles computation.
The condition if sm_patition_factor > 0: is incorrect. When blocking_tiles = 0 (which occurs when total_tiles ≤ streamk_programs), sm_patition_factor = max(0 // total_sm, 1) evaluates to 1, causing the condition to pass despite having no blocking tiles to process. This leads to invalid tile indexing.
Change line 133 to if blocking_tiles > 0: to ensure the loop executes only when blocking tiles actually exist.
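A minimal host-side sketch of the degenerate case and the suggested guard, reusing the variable names from the review (the tile count is an arbitrary illustrative value):

```python
streamk_programs = total_sm = 108
total_tiles = 100                                  # fewer tiles than SMs
streamk_tiles = total_tiles % streamk_programs     # 100
if total_tiles - streamk_tiles > streamk_programs:
    streamk_tiles += streamk_programs
blocking_tiles = total_tiles - streamk_tiles       # 0: there is no full-tile wave at all
sm_patition_factor = max(blocking_tiles // total_sm, 1)   # still 1 because of max(..., 1)

print(sm_patition_factor > 0)   # True  -> old guard would run the full-tile loop anyway
print(blocking_tiles > 0)       # False -> suggested guard correctly skips it
```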
🤖 Prompt for AI Agents
In examples/gemm_streamk/example_tilelang_gemm_streamk.py around lines 132 to
144, the guard uses `if sm_patition_factor > 0:` which can be true even when
there are no blocking tiles (leading to invalid tile indexing); replace that
condition with `if blocking_tiles > 0:` so the full-tiles computation only runs
when blocking_tiles actually exist.
```python
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
```
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file structure and find the main_no_split and main_split functions
fd "example_warp_specialize_flashmla.py"Repository: tile-ai/tilelang
Length of output: 120
🏁 Script executed:
# Once we locate the file, let's examine the main_no_split function and understand the parameters
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,40p'
Repository: tile-ai/tilelang
Length of output: 1294
🏁 Script executed:
# Search for main_split function to see if a split variant exists
rg -n "def main_split|@T.prim_func" examples/warp_specialize/example_warp_specialize_flashmla.py | head -20Repository: tile-ai/tilelang
Length of output: 79
🏁 Script executed:
# Search for usages of glse and Output_partial within the file
rg -n "glse|Output_partial" examples/warp_specialize/example_warp_specialize_flashmla.py | head -40Repository: tile-ai/tilelang
Length of output: 387
🏁 Script executed:
# Let's see the full function body and look for any usage of glse or Output_partial in the main_no_split function
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,300p' | grep -E "(glse|Output_partial|def main_split|def ref_program)"
Repository: tile-ai/tilelang
Length of output: 269
🏁 Script executed:
# Let's check if there's a main_split function by searching more broadly
rg -n "def main_split" examples/warp_specialize/example_warp_specialize_flashmla.pyRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Let's look at the ref_program function to understand what it does with these parameters
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '299,330p'
Repository: tile-ai/tilelang
Length of output: 1763
🏁 Script executed:
# Let's check the file size and look for any calls to main_no_split with arguments
wc -l examples/warp_specialize/example_warp_specialize_flashmla.py
Repository: tile-ai/tilelang
Length of output: 124
🏁 Script executed:
# Let's look at the ref_program function completely to see if glse and Output_partial are actually used
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '299,362p'
Repository: tile-ai/tilelang
Length of output: 3534
🏁 Script executed:
# Let's check how main_no_split is actually called/tested
rg -n "main_no_split" examples/warp_specialize/example_warp_specialize_flashmla.pyRepository: tile-ai/tilelang
Length of output: 114
🏁 Script executed:
# Let's also check if there's a split version anywhere in the codebase
fd ".*flashmla.*" --type fRepository: tile-ai/tilelang
Length of output: 167
🏁 Script executed:
# Check the function signature and call context more carefully
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '340,362p'
Repository: tile-ai/tilelang
Length of output: 1437
🏁 Script executed:
# Let's check the flashattn function to understand the architecture better
rg -n "def flashattn" examples/warp_specialize/example_warp_specialize_flashmla.py -A 5Repository: tile-ai/tilelang
Length of output: 375
🏁 Script executed:
# Let's look at the full main_no_split function body to confirm no usage of glse/Output_partial
cat -n examples/warp_specialize/example_warp_specialize_flashmla.py | sed -n '20,300p'
Repository: tile-ai/tilelang
Length of output: 15063
Remove or mark the unused glse and Output_partial parameters in the function signature.
These parameters are declared in both main_no_split and ref_program but never referenced in their implementations. Since they add unnecessary interface complexity, either remove them entirely or prefix with underscore (_glse, _Output_partial) to explicitly mark them as intentionally unused.
🧰 Tools
🪛 Ruff (0.14.10)
26-26: Unused function argument: glse
(ARG001)
27-27: Unused function argument: Output_partial
(ARG001)
🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_flashmla.py around lines 26
to 28, the function signatures for main_no_split and ref_program include unused
parameters glse and Output_partial; either remove these parameters from both
function signatures and all call sites, or rename them to _glse and
_Output_partial to mark them as intentionally unused; update any type
annotations and docstrings accordingly to keep signatures consistent across
definitions and uses.
@codex review
Codex Review: Didn't find any major issues. You're on a roll. 🚀
waiting to rebase on #1509
@codex review
Actionable comments posted: 1
♻️ Duplicate comments (2)
examples/warp_specialize/example_warp_specialize_flashmla.py (1)
26-27: Unused parameters flagged in previous review. The glse and Output_partial parameters remain unused as flagged in the previous review comment on lines 26-28. Please address the earlier feedback to either remove these parameters or mark them as intentionally unused with underscore prefixes.

examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
87-124: Code duplication concern from previous review still applies. The inlined attention logic in this block appears functionally correct and follows the standard FlashAttention pattern with proper causal masking, sliding window support, and numerical stability. However, as noted in the previous review, this same logic is duplicated across multiple variant files (e.g., example_mha_sink_fwd_bhsd_wgmma_pipelined.py), creating maintenance challenges.
🧹 Nitpick comments (4)
examples/flash_decoding/example_mha_inference.py (1)
158-165: Consider adding a corresponding guard in main or flashattn. ref_program asserts causal is False, but a caller could invoke main(causal=True) or construct a kernel directly with causal mode. Given the incomplete causal split handling noted in the TODO comments, consider adding an early check or assertion at the kernel level to fail fast with a clear message rather than relying on the reference function's assertion during validation.

examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
341-343: Consider removing unused tensor allocations. O_tl and S_tl are allocated at lines 341-342 but immediately overwritten by the kernel call at line 343. These pre-allocations are unnecessary since the kernel returns new tensors based on out_idx.

🔎 Proposed fix

```diff
-    O_tl = torch.zeros_like(Q)
-    S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
     O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
```
877-880: Hardcoded argument overrides disable CLI flexibility. These lines override parsed arguments, which defeats the purpose of the argparse configuration. If this is intentional for testing, consider removing these lines or adding a comment explaining the temporary override.
🔎 Proposed fix

```diff
-    args.test_sink = True
-    args.test_varlen = False
-    args.dtype = T.float16
-    args.num_split = 1
```

examples/deepseek_mla/example_mla_decode.py (1)
122-180: Minor inconsistency: Consider aligning with main_split for maintainability.

The main_no_split implementation differs slightly from main_split in two ways:

1. Accumulator clearing approach (line 159 vs line 66-67): Uses clear_accum=True parameter instead of explicit T.clear(acc_s). Both are correct.
2. Final GEMM input (line 176 vs line 85): Uses S_shared directly instead of copying to acc_s_cast. The main_split version allocates acc_s_cast and copies through it (lines 43, 80), while main_no_split omits this buffer entirely.
S_shareddirectly instead of copying toacc_s_cast. Themain_splitversion allocatesacc_s_castand copies through it (lines 43, 80), whilemain_no_splitomits this buffer entirely.While both approaches are functionally correct (both pass float16 to the final GEMM), the inconsistency may impact maintainability. Consider aligning the implementations unless the difference is intentional optimization.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
- examples/attention_sink/example_mha_sink_fwd_bhsd.py
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
- examples/deepseek_mla/example_mla_decode.py
- examples/deepseek_mla/example_mla_decode_paged.py
- examples/flash_decoding/example_gqa_decode_varlen_logits.py
- examples/flash_decoding/example_mha_inference.py
- examples/warp_specialize/example_warp_specialize_flashmla.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
🧰 Additional context used
🪛 Ruff (0.14.10)
examples/warp_specialize/example_warp_specialize_flashmla.py
26-26: Unused function argument: glse
(ARG001)
27-27: Unused function argument: Output_partial
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
examples/warp_specialize/example_warp_specialize_flashmla.py (2)
20-21: LGTM! Macro replaced with prim_func as intended. The transition from a macro-based implementation to @T.prim_func aligns with the PR's cleanup objectives to remove unnecessary macros and use explicit primitive function entry points.

347-347: Verify the inconsistent batch size default. The argparse default for --batch is now 132, but the function main() at line 328 still has batch=1 as its default parameter. This creates an inconsistency between programmatic and CLI usage. If this is intentional (e.g., 132 is optimized for specific hardware or benchmarking), consider documenting the rationale. Otherwise, align both defaults for consistency.
examples/flash_decoding/example_mha_inference.py (4)
20-28: LGTM on the inlined prim_func signature.The function signature correctly defines all input/output tensors with appropriate shapes for the split/combine flash decoding pattern. The transition from macro-based to inline implementation is clean.
56-67: Verify behavior whenseqlen_kvis not evenly divisible bynum_split.The integer division
seqlen_kv // num_spliton lines 60, 65, and 102 will truncate, potentially skipping the lastseqlen_kv % num_splitKV elements. If this edge case is expected to be handled by callers (i.e., only passing divisible lengths), consider adding an assertion or documenting the constraint.Additionally, the TODO at line 56 notes that causal split case handling is incomplete—the causal loop range (line 58) references full
seqlen_kvbut K/V indexing still uses split portions, which may produce incorrect results for causal attention with splits.
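A concrete illustration of the truncation concern (the numbers are arbitrary):

```python
seqlen_kv, num_split = 100, 3
per_split = seqlen_kv // num_split   # 33 keys handled by each split
covered = per_split * num_split      # 99
print(seqlen_kv - covered)           # 1 key silently skipped unless callers pad or assert divisibility
```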
76-105: LGTM on the inlined softmax and rescale logic.The online softmax implementation correctly follows the FlashAttention pattern with exp2/log2(e) optimization. The rescaling of
acc_obefore GEMM and the log-sum tracking are properly implemented.
115-153: LGTM on the combine phase.The combine kernel correctly implements the log-sum-exp reduction across splits, matching the pattern in the reference implementation. The use of
lse_max_localfor numerical stability in the exp2 computations is appropriate.examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
96-133: Code correctness verified; duplication concern already raised.The inlined attention logic (K copy, causal/sliding-window masking, flash attention softmax, V processing) is correct and follows standard patterns. The past review comment on this range already comprehensively addresses the code duplication concern across multiple files. I have no additional issues to raise beyond what has been covered.
examples/flash_decoding/example_gqa_decode_varlen_logits.py (4)
201-202: LGTM!Removing
debug_root_pathfrom the decorator is a good cleanup. Theout_idx=[-2, -1]correctly identifies the output tensors (Output and S) based on their positions in the function signature.
219-228: LGTM!The conversion from macro-based implementation to explicit
@T.prim_funcis clean. TheOutputtensor correctly usesshape_owhich is defined earlier as[batch, heads, dim], matching the expected output shape for the decode operation.
375-376: Consistent with the tilelang path, but same verification applies.The contiguity assertions are relaxed in both the tilelang and Triton implementations. While consistency is good, the underlying Triton kernel (
_fwd_kernel_varlen) also needs to correctly handle strided K/V tensors. The kernel does pass explicit stride parameters, which suggests strided access is intended.
336-337: The K and V contiguity assertions are safely removed.The
flashattn_gqa_decode_no_splitkernel uses strided memory access viaT.copy(), which is designed to handle arbitrary tensor layouts. Unlike the paged variant (which requires contiguity for the block table mechanism), the non-paged variant correctly supports non-contiguous K and V tensors.examples/deepseek_mla/example_mla_decode.py (3)
24-33: LGTM: Clean prim_func signatureThe function signature correctly exposes all necessary parameters for the split attention path, including the intermediate outputs (
glse,Output_partial) and final output (Output).
34-93: LGTM: Split attention kernel logic is correctThe split attention computation properly implements online softmax with per-split accumulation. The use of
acc_s_castfor the final gemm ensures proper dtype handling.
94-120: LGTM: Combine logic correctly merges split resultsThe log-sum-exp merging logic properly combines partial attention outputs from multiple splits using numerically stable exp2 operations.
examples/deepseek_mla/example_mla_decode_paged.py (4)
27-38: LGTM: Improved parameter naming conventionThe function signature correctly uses lowercase snake_case for
block_tableandcache_seqlens, which is more Pythonic than the previous uppercase macro-style naming.
80-82: LGTM: Correct boundary maskingThe masking logic properly handles out-of-bounds elements by checking against
cache_seqlens[bx], ensuring invalid positions are set to negative infinity before the softmax operation.
202-205: LGTM: Clean dispatch logicThe dispatch correctly selects between split and no-split variants based on the
num_splitparameter.
169-200: Remove the block_table offset concern—line 172 is correct for the no_split reverse iteration.The offset calculation at line 172 is correct because the no_split kernel processes all blocks for a given batch element (no
startoffset like main_split). Since there is nostartoffset, both the block index(k * block_N) // block_sizeand the offset(k * block_N) % block_sizecorrectly use onlyk * block_N.The reverse iteration (
k = loop_range - 1 - kr) is intentional and the conditional masking atkr == 0correctly handles the last block, which may be partially filled. This pattern is consistent with other paged KV cache implementations in the codebase.Likely an incorrect or invalid review comment.
```diff
 total_blocks = T.ceildiv(cache_seqlens[bx], block_N)
 blocks_per_split = T.floordiv(total_blocks, num_split)
 remaining_blocks = T.floormod(total_blocks, num_split)
 loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
 start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N

 for k in T.Pipelined(loop_range, num_stages=2):
-    kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
+    kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
     T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
```
Critical: Incorrect offset calculation in block_table indexing
Line 72 has a bug in the offset calculation within the block:
```python
kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
```

The offset calculation (k * block_N) % block_size ignores the start position, which can lead to incorrect memory access when start is not aligned to block_size.

Example scenario:
- block_size = 128, block_N = 64, start = 192 (split doesn't start at a block boundary)
- k = 0: the code computes offset = (0 * 64) % 128 = 0, but position 192 is actually at offset 64 within its block.

🔎 Proposed fix

```diff
- kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
+ kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size
```

🤖 Prompt for AI Agents
examples/deepseek_mla/example_mla_decode_paged.py around lines 65 to 73: the
block_table indexing computes the intra-block offset using (k * block_N) %
block_size which ignores the split start and can point to the wrong position;
change the index and offset calculations to use the absolute position = start +
k * block_N, e.g. compute idx = (position) // block_size and offset = (position)
% block_size, then compute kv_start = block_table[bx, idx] * block_size + offset
so the start contribution is included and accesses are correct.
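A small numeric check of the proposed fix, using the scenario from the comment (block_size=128, block_N=64, start=192) and a hypothetical block_table row; the physical block ids are made up for illustration.

```python
block_size, block_N = 128, 64
start, k = 192, 0                 # the split begins mid-block
block_table_row = [7, 9, 4]       # hypothetical physical block ids for this batch element

pos = start + k * block_N         # absolute KV position = 192 -> logical block 1, offset 64

buggy = block_table_row[pos // block_size] * block_size + (k * block_N) % block_size
fixed = block_table_row[pos // block_size] * block_size + pos % block_size

print(buggy)   # 9 * 128 + 0  = 1152 -> start of physical block 9, 64 elements too early
print(fixed)   # 9 * 128 + 64 = 1216 -> correct position inside physical block 9
```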
Codex Review: Didn't find any major issues. Delightful! 👍
Summary by CodeRabbit