[Clean][Refactor] Phaseout Legacy Pass ParallelLoopTransformer #1672

LeiWang1999 merged 16 commits into tile-ai:main

Conversation
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation.
* Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage.
* Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability.
* Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs.
* Explained the Delta Debugging methodology, usage, and parameters, and provided examples for clarity.
* Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing its time-saving aspects.
* Included tips for effective usage and a reference to a complete example in the documentation.
* Refactor: Simplify the flash attention function signatures in example scripts to accept parameters directly, enhancing clarity and usability.
* Update the kernel invocation logic in the SparseFlashAttn class to align with the new function signatures.
* Remove redundant code and improve the organization of dynamic parameters for better maintainability.
* Enhance the handling of cache sequence lengths and block sizes in the regression performance tests, ensuring consistency across examples.
* Clean up unused imports and streamline the code for improved readability and performance.
…ng examples
* Refactor: Consolidate multi-line function calls into single lines for better clarity in `tilelang_buggy.py` and `tilelang_minimized_expected.py`.
* Remove unnecessary blank lines and streamline print statement formatting for consistency.
* Enhance overall code organization and maintainability across example scripts.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Replaces nested kernel_func-based sparse GQA/flash attention kernels with TileLang-decorated top-level prim_funcs and updates call sites to invoke the new wrapper directly; removes the ParallelLoopTransformer utilities and adjusts layout/parallelization lowering to operate on fused loops.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor Caller
    participant SparseFlashAttn
    participant flashattn_wrapper
    participant prim_main
    participant Memory as KV/O caches
    Caller->>SparseFlashAttn: forward(query, key, value, auxiliaries)
    SparseFlashAttn->>flashattn_wrapper: construct/obtain wrapper(block_N, block_H, num_stages, threads, ...)
    SparseFlashAttn->>flashattn_wrapper: invoke(query, key, value, block_indices, cache_seqlens, block_table, glse, output_partial)
    flashattn_wrapper->>prim_main: call prim_func main with tensors & params
    prim_main->>Memory: read/write K/V caches, accumulate Output_partial/Output
    prim_main-->>flashattn_wrapper: return outputs
    flashattn_wrapper-->>SparseFlashAttn: return outputs
    SparseFlashAttn-->>Caller: return final output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches

📜 Recent review details

Configuration used: defaults
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
💤 Files with no reviewable changes (1)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
418-418: Bug: `sparse_kernel.kernel` does not exist - will raise `AttributeError`.

The `SparseFlashAttn` class (lines 159-207) does not define a `kernel` attribute. This will fail at runtime when `run_regression_perf` is called. Compare with `example_tilelang_sparse_gqa_decode_varlen_mask.py` (lines 434-444), which correctly creates the kernel directly using `flashattn(...)`.

🐛 Suggested fix

```diff
-    kernel = sparse_kernel.kernel
+    kernel = flashattn(
+        batch,
+        heads,
+        heads_kv,
+        dim,
+        dim_v,
+        block_N=block_size,
+        block_H=sparse_kernel.block_H,
+        num_stages=2,
+        threads=128,
+    )
```
♻️ Duplicate comments (2)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
311-319: Same issue as varlen_mask: `assert_close` doesn't actually assert.

See the previous comment on the varlen_mask file. Consider making this function consistent across files by either renaming it or adding actual assertion behavior.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)
144-148: Same sentinel value concern as the varlen_mask file.

The combine kernel uses `lse_local_split != 0` to detect valid splits, with the same potential issues around uninitialized `glse` memory. See the comment on the varlen_mask file for details.
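For reference, a toy PyTorch sketch of the sentinel-based combine these comments describe; the shapes, names, and softmax-weighted merge are illustrative assumptions, not the TileLang kernel itself:

```python
import torch


def combine_splits(glse: torch.Tensor, out_partial: torch.Tensor) -> torch.Tensor:
    """glse: [num_splits] log-sum-exp per split; out_partial: [num_splits, dim]."""
    valid = glse != 0                              # zero lse marks a split that never ran
    if not valid.any():
        return torch.zeros(out_partial.shape[-1])
    weights = torch.softmax(glse[valid], dim=0)    # exp(lse_i) / sum_j exp(lse_j)
    return (weights[:, None] * out_partial[valid]).sum(dim=0)


glse = torch.tensor([0.0, 1.5, 0.0, 2.0])          # splits 0 and 2 produced nothing
out_partial = torch.randn(4, 8)
print(combine_splits(glse, out_partial).shape)     # torch.Size([8])
```

The sketch only shows why a zero sentinel is enough to mask out inactive splits when the valid splits always write a non-zero log-sum-exp.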
🧹 Nitpick comments (3)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
303-311: Function name `assert_close` is misleading - it doesn't actually assert.

The function only prints a message when values don't match but doesn't raise an exception. Consider either:

- Renaming to `check_close` or `print_close_check` to reflect actual behavior
- Adding an actual assertion to match the name
🔧 Suggested fix to make it actually assert
```diff
 def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3):
     all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
     print(name + " all_close={}".format(all_close))
     if not all_close:
         diff = (expect - actual).abs()
         print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
         max_indices = torch.nonzero(diff == diff.max().item())
         first_index = tuple(max_indices[0].tolist())
         print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
+        raise AssertionError(f"{name} tensors are not close: max diff = {diff.max().item()}")
```
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)

327-327: Unnecessary `int()` cast - `math.ceil` already returns an integer.

As flagged by static analysis (RUF046), `math.ceil()` returns an integer in Python 3, so the `int()` wrapper is redundant.

🔧 Suggested fix

```diff
-    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
+    max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)
```
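A quick plain-Python check of the behavior the comment relies on:

```python
import math

# math.ceil returns an int in Python 3 even for float arguments,
# so wrapping it in int(...) changes nothing.
x = math.ceil(7 / 2)
print(x, type(x))  # 4 <class 'int'>
assert isinstance(x, int)
```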
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

166-167: Remove debug `print(main)` statement.

This debug print statement will output the kernel IR on every JIT compilation, which clutters output in production. Consider removing it or guarding it behind a debug/verbose flag.

🔧 Suggested fix

```diff
-    print(main)
     return main
```
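If the IR dump is worth keeping, one option is an opt-in helper like the sketch below; the helper name and environment variable are hypothetical, not part of this PR:

```python
import os


def maybe_dump_ir(prim_func, env_var: str = "TILELANG_EXAMPLE_DEBUG_IR"):
    """Print the lowered kernel IR only when the (hypothetical) env var opts in."""
    if os.environ.get(env_var, "0") == "1":
        print(prim_func)

# Usage inside flashattn(), replacing the unconditional print:
#     maybe_dump_ir(main)
#     return main
```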
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- src/op/atomic_add.cc
- src/op/copy.cc
- src/transform/common/loop_parallel_transform_utils.h
- src/transform/layout_inference.cc
💤 Files with no reviewable changes (2)
- src/transform/common/loop_parallel_transform_utils.h
- src/transform/layout_inference.cc
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.cc (4)

- src/op/parallel.h (1)
  - ParallelOp (170-178)
- tilelang/jit/adapter/utils.py (1)
  - is_cpu_target (106-107)
- tilelang/transform/__init__.py (1)
  - VectorizeLoop (334-342)
- src/transform/loop_vectorize.cc (4)
  - VectorizeLoop (396-406)
  - VectorizeLoop (396-396)
  - VectorizeLoop (408-418)
  - VectorizeLoop (408-409)
src/op/atomic_add.cc (2)

- tilelang/language/tir/op.py (1)
  - ret (1878-1891)
- src/transform/loop_partition.cc (2)
  - PartitionLoop (64-167)
  - PartitionLoop (64-65)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)

- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (4)
  - flashattn (19-156)
  - main (38-154)
  - main (322-371)
  - SparseFlashAttn (159-207)
- tilelang/language/loop.py (2)
  - Pipelined (97-134)
  - Parallel (13-72)
🪛 Ruff (0.14.11)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
327-327: Value being cast to int is already an integer
Remove unnecessary int call
(RUF046)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (13)
src/op/atomic_add.cc (1)
785-790: LGTM - Clean removal of the intermediate transformation step.

The refactoring correctly replaces `transformed_loop` with `fused_loop` at both the `AtomicAddInferLayout` and `PartitionLoop` call sites, aligning with the goal of phasing out `ParallelLoopTransformer`. The substitution logic has been relocated to `layout_inference.cc`, where `LayoutInferencer::Substitute` now performs the transformation after thread-binding analysis, and `ParallelLoopTransformer` has been fully removed from the codebase.

src/op/copy.cc (1)
719-727: Clean removal of the intermediate transformation step is complete and consistent across the codebase.

The refactor correctly eliminates the `ParallelLoopTransformer::Substitute` step by passing `fused_loop` directly to both `ParallelOp` and `VectorizeLoop`. No references to `ParallelLoopTransformer` remain in the codebase, and the pattern is consistent with similar lowering paths in `fill.cc` and `layout_inference.cc`, where `ParallelOp` receives a `For` object directly.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (4)
13-18: LGTM! Clean decorator configuration.

The `@tilelang.jit` decorator with `out_idx=[-1]` indicating the last parameter is the output, and `TL_ENABLE_FAST_MATH` enabled for performance, is a clean and explicit configuration.
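For readers less familiar with the pattern, here is a minimal sketch of a `@tilelang.jit(out_idx=[-1])` kernel - a generic copy kernel, not this PR's attention kernel, and with the fast-math flag omitted since its exact spelling is a configuration detail not shown here:

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])  # the last tensor argument is treated as the kernel output
def copy_kernel(M, N, block_M=64, block_N=64, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[bx * block_M + i, by * block_N + j] = A[bx * block_M + i, by * block_N + j]

    return main


# Compile-time tile/shape parameters are fixed here; runtime tensors are passed when calling it.
# kernel = copy_kernel(1024, 1024)
```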
186-199: LGTM! Clean kernel instantiation and invocation pattern.

The new pattern of constructing the kernel with compile-time parameters via `flashattn(...)` and then invoking it with runtime tensors is cleaner and more explicit than the previous approach.
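Schematically, the two-step pattern looks like the fragment below; the argument names are taken from the suggested fix and the sequence diagram above, and the surrounding tensors are assumed to exist rather than defined here:

```python
# Step 1: bake in compile-time parameters; this returns a compiled kernel object.
kernel = flashattn(
    batch, heads, heads_kv, dim, dim_v,
    block_N=block_size, block_H=64, num_stages=2, threads=128,
)

# Step 2: call it with runtime tensors only.
output = kernel(query, key_cache, value_cache,
                block_indices, cache_seqlens, block_table, glse, output_partial)
```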
384-386: LGTM! Deterministic seeding for reproducibility.

Setting both `torch.manual_seed(42)` and `torch.cuda.manual_seed_all(42)` ensures reproducible test results across runs.
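For reference, the seeding pair the comment refers to is simply:

```python
import torch

# Seed the CPU generator and all CUDA generators before creating random inputs,
# so repeated benchmark runs see identical tensors.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)  # safe to call even on CPU-only machines
```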
128-132: The `!= 0` check is a valid sentinel pattern in this code.

Since `glse` is always written at lines 108-110 (even when `has_valid_block` is false), and invalid splits retain their initialized logsum value of 0 while valid splits compute a non-zero log2 value, this approach is sound. No issue to address.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
13-18: LGTM! Consistent decorator configuration across files.

The `@tilelang.jit` configuration matches the varlen_mask variant, maintaining consistency across the codebase.
74-78: LGTM! Block scheduling logic is correct.

The scheduling logic correctly distributes blocks across splits using `floordiv` and `floormod` operations.
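For intuition only, here is one common floordiv/floormod split scheme in plain Python; this is an illustrative assumption, not necessarily the exact formula used in the kernel:

```python
def split_range(num_blocks: int, num_split: int, sid: int) -> range:
    """Give the first (num_blocks % num_split) splits one extra block so every
    block is covered exactly once; uses only floor division and modulo."""
    base, rem = num_blocks // num_split, num_blocks % num_split
    start = sid * base + min(sid, rem)
    count = base + (1 if sid < rem else 0)
    return range(start, start + count)


# Sanity check: the splits partition [0, num_blocks) with no gaps or overlaps.
assert sum(len(split_range(13, 4, s)) for s in range(4)) == 13
```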
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (5)

15-21: LGTM! Extended JIT signature for paged KV cache.

The `flashattn` function signature correctly includes additional parameters for paged attention: `page_block_size` and `num_pages`.
39-40: LGTM! Proper constraint validation for paged cache.

The assertion ensures `block_N` divides evenly into `page_block_size`, which is required for correct block-to-page mapping via `block_ratio`.
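A hedged sketch of the index arithmetic this divisibility enables; the names mirror the comment, but the kernel's actual indexing may differ:

```python
def block_to_page(block_table, seq_idx: int, logical_block: int,
                  block_N: int, page_block_size: int):
    """Map a logical attention block to (physical page, offset within the page)."""
    assert page_block_size % block_N == 0, "block_N must evenly divide page_block_size"
    block_ratio = page_block_size // block_N          # attention blocks per KV page
    page = block_table[seq_idx][logical_block // block_ratio]
    offset = (logical_block % block_ratio) * block_N  # starting row inside that page
    return page, offset


# Example: 256-token pages split into 64-token attention blocks -> 4 blocks per page.
print(block_to_page([[7, 9]], 0, 5, 64, 256))  # (9, 64): block 5 lives in page 9 at row 64
```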
170-222: LGTM! Well-structured SparseFlashAttn for paged KV cache.

The class correctly stores the paged cache parameters (`page_block_size`, `block_N`, `num_pages`) and passes them to `flashattn` in the forward method.
503-505: LGTM! Reproducibility seeds added.

Manual seeds for both the PyTorch and CUDA RNGs ensure reproducible benchmark results.
610-625: LGTM! Correct kernel instantiation pattern in the regression test.

Unlike `example_tilelang_sparse_gqa_decode_varlen_indice.py`, this file correctly creates the kernel directly using `flashattn(...)` rather than accessing a non-existent class attribute.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
* Cleaned up the fill.cc file by removing the unused import of loop_parallel_transform_utils.h, improving code clarity and maintainability.
This pull request refactors and simplifies the kernel invocation pattern for the `flashattn` function in both `example_tilelang_sparse_gqa_decode_paged.py` and `example_tilelang_sparse_gqa_decode_varlen_indice.py`. The main improvements are making the kernel instantiation more dynamic and consistent, reducing code duplication, and cleaning up legacy or commented code. Additionally, it introduces minor bug fixes and improves code readability.

Refactoring and Simplification of Kernel Invocation:

- Updated the `flashattn` function in both files to accept all necessary parameters up front and return a kernel function, removing the previous pattern of partially applying parameters and then calling the returned function. This change makes kernel usage more explicit and less error-prone. [1] [2]
- Updated usage of the `flashattn` kernel in the `SparseFlashAttn` class and related functions to match the new invocation pattern, removing the need to store the kernel as a class attribute and reducing indirection. [1] [2] [3] [4] [5]

Bug Fixes and Code Corrections:

- Changed `has_valid_block` to use `T.bool` instead of the string `"bool"` in both files, ensuring correct type usage. [1] [2]
- Replaced `num_blocks` with `num_pages` in `run_regression_perf` and related cache allocations. [1] [2]

Code Cleanup and Readability Improvements:

- Revised `assert_close` for clarity and updated its usage. [1] [2]

Reproducibility and Testing Improvements:

Import and Dependency Updates:
Summary by CodeRabbit
Refactor
Tests
Chores
✏️ Tip: You can customize this high-level summary in your review settings.