
[Clean][Refactor] Phaseout Legacy Pass ParallelLoopTransformer#1672

Merged
LeiWang1999 merged 16 commits into tile-ai:main from LeiWang1999:cleanup_01114
Jan 15, 2026
Conversation


@LeiWang1999 (Member) commented Jan 14, 2026

This pull request refactors and simplifies the kernel invocation pattern for the flashattn function in both example_tilelang_sparse_gqa_decode_paged.py and example_tilelang_sparse_gqa_decode_varlen_indice.py. The main improvements are making the kernel instantiation more dynamic and consistent, reducing code duplication, and cleaning up legacy or commented code. Additionally, it introduces minor bug fixes and improves code readability.

Refactoring and Simplification of Kernel Invocation:

  • Refactored the flashattn function in both files to accept all necessary parameters up front and return a kernel function, removing the previous pattern of partially applying parameters and then calling the returned function. This change makes kernel usage more explicit and less error-prone. [1] [2]
  • Updated all usages of the flashattn kernel in the SparseFlashAttn class and related functions to match the new invocation pattern (sketched below), removing the need to store the kernel as a class attribute and reducing indirection. [1] [2] [3] [4] [5]
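For illustration, a minimal sketch of the new call pattern, using parameter and tensor names drawn from these example files (the exact argument lists in the scripts may differ slightly):

    # Compile-time configuration is passed once; a ready-to-call kernel comes back.
    kernel = flashattn(
        batch, heads, heads_kv, dim, dim_v,
        block_N=block_size, block_H=block_H,
        num_stages=2, threads=128,
    )
    # Runtime tensors are passed at call time; previously the partially applied
    # result of flashattn(...) had to be called again to obtain the kernel.
    output = kernel(query, key, value, block_indices, cache_seqlens, block_table,
                    glse, output_partial)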

Bug Fixes and Code Corrections:

  • Fixed the allocation of has_valid_block to use T.bool instead of the string "bool" in both files, ensuring correct type usage (illustrated below). [1] [2]
  • Replaced the incorrect variable num_blocks with num_pages in run_regression_perf and related cache allocations. [1] [2]
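A small before/after sketch of the dtype fix; the allocation helper shown here (T.alloc_local) is assumed from the surrounding TileLang examples rather than quoted from the diff:

    # Before: dtype given as a plain string
    # has_valid_block = T.alloc_local([1], "bool")

    # After: pass the TileLang dtype object so the boolean type is handled correctly
    has_valid_block = T.alloc_local([1], T.bool)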

Code Cleanup and Readability Improvements:

  • Removed commented and legacy code, as well as unnecessary print/debug statements, to improve code clarity and maintainability. [1] [2]
  • Renamed the debug function to assert_close for clarity and updated its usage. [1] [2]
  • Added a TODO comment for future parallelization support and improved inline comments. [1] [2]

Reproducibility and Testing Improvements:

  • Set manual seeds for PyTorch and CUDA in the regression performance test to ensure reproducibility.
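The seeding pattern fixes both the CPU and CUDA random-number generators (the seed value 42 matches what the reviewers note later in this PR):

    import torch

    # Seed the CPU RNG and every CUDA device RNG so benchmark inputs are reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)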

Import and Dependency Updates:

  • Removed unused imports and cleaned up the import section for better maintainability.

Summary by CodeRabbit

  • Refactor

    • Reworked sparse attention kernel path to use a compiled, top‑level kernel entry, simplifying invocation and runtime behavior.
    • Removed legacy parallel loop transformation utilities and streamlined loop lowering.
  • Tests

    • Updated regression and benchmark flows to use the new kernel invocation pattern and to run deterministically (seeded).
  • Chores

    • Cleaned up debug/test output and unified assertion helpers.


KEKE046 and others added 15 commits January 8, 2026 06:44
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation.
* Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage.
* Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability.
* Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs.
* Explained Delta Debugging methodology, usage, parameters, and provided examples for clarity.
* Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing time-saving aspects.
* Included tips for effective usage and a reference to a complete example in the documentation.
* Refactor: Simplify the flash attention function signatures in example scripts to accept parameters directly, enhancing clarity and usability.
* Update the kernel invocation logic in SparseFlashAttn class to align with the new function signatures.
* Remove redundant code and improve the organization of dynamic parameters for better maintainability.
* Enhance the handling of cache sequence lengths and block sizes in the regression performance tests, ensuring consistency across examples.
* Clean up unused imports and streamline the code for improved readability and performance.
…ng examples

* Refactor: Consolidate multi-line function calls into single lines for better clarity in `tilelang_buggy.py` and `tilelang_minimized_expected.py`.
* Remove unnecessary blank lines and streamline print statement formatting for consistency.
* Enhance overall code organization and maintainability across example scripts.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Jan 14, 2026

📝 Walkthrough


Replaces nested kernel_func-based sparse GQA/flash attention kernels with TileLang-decorated top-level prim_funcs and updates call sites to invoke the new wrapper directly; removes the ParallelLoopTransformer utilities and adjusts layout/parallelization lowering to operate on fused loops.

Changes

• Attention kernel refactor
  Files: examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py, examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  Replaced nested/inlined kernel_func implementations with a @tilelang.jit/prim_func flashattn that accepts explicit parameters (block_N, block_H, num_stages, threads, num_pages/page sizing) and tensors (Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial/Output). SparseFlashAttn now constructs and calls the returned wrapper directly; test utilities were renamed (debug → assert_close), and benchmark/test shapes and seed handling were updated to use paging nomenclature. A structural sketch of the new kernel entry follows this table.
• Lowering: atomic_add & copy
  Files: src/op/atomic_add.cc, src/op/copy.cc, src/op/fill.cc
  Removed the dependency on loop_parallel_transform_utils.h and eliminated the parallel-loop substitution step; the code now uses fused_loop directly for layout inference, partitioning, parallel-op creation, and vectorization. Minor include removals; transformed_loop usages replaced with fused_loop.
• Parallel transform utilities removed
  Files: src/transform/common/loop_parallel_transform_utils.h
  Deleted the header and its public tvm::tl::ParallelLoopTransformer and nested BufferAccessCollector classes (the Substitute helper and related logic removed).
• Layout inference reorder
  Files: src/transform/layout_inference.cc
  Reordered substitution logic to run LayoutInferencer::Substitute() (with the skip_thread_partition flag) after computing thread-binding/skip conditions; collector analysis now runs on the original function body before substitution.
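To make the "top-level prim_func" structure concrete, here is a minimal, hypothetical TileLang kernel in the same shape as the refactored flashattn: a @tilelang.jit-decorated factory that takes compile-time parameters and returns a T.prim_func. The body is a trivial tiled copy, not the attention kernel, and the name make_copy_kernel is illustrative only.

    import tilelang
    import tilelang.language as T

    # Hypothetical stand-in: same structure as the refactored flashattn entry point,
    # but with a trivial tiled-copy body instead of sparse GQA decode attention.
    @tilelang.jit(out_idx=[-1])  # the last tensor parameter is treated as the output
    def make_copy_kernel(M, N, block_M=64, block_N=64, threads=128, dtype="float16"):

        @T.prim_func
        def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
            with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    B[bx * block_M + i, by * block_N + j] = A[bx * block_M + i, by * block_N + j]

        return main

    # Compile-time parameters up front, runtime tensors at call time:
    # kernel = make_copy_kernel(1024, 1024)
    # b = kernel(a)  # with out_idx=[-1], the output tensor is allocated and returned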

Sequence Diagram(s)

sequenceDiagram
    actor Caller
    participant SparseFlashAttn
    participant flashattn_wrapper
    participant prim_main
    participant Memory as KV/O caches

    Caller->>SparseFlashAttn: forward(query,key,value, auxiliaries)
    SparseFlashAttn->>flashattn_wrapper: construct/obtain wrapper(block_N,block_H,num_stages,threads,...)
    SparseFlashAttn->>flashattn_wrapper: invoke(query,key,value, block_indices, cache_seqlens, block_table, glse, output_partial)
    flashattn_wrapper->>prim_main: call prim_func main with tensors & params
    prim_main->>Memory: read/write K/V caches, accumulate Output_partial/Output
    prim_main-->>flashattn_wrapper: return outputs
    flashattn_wrapper-->>SparseFlashAttn: return outputs
    SparseFlashAttn-->>Caller: return final output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • Rachmanino
  • chengyupku

Poem

🐰🌿 I hopped through kernels, neat and spry,
Rewrote the paths where primitives fly.
Loops trimmed down, caches paged with care,
Now flashattn sings in TileLang air. 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 15.38%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title '[Clean][Refactor] Phaseout Legacy Pass ParallelLoopTransformer' accurately describes the core refactoring objective, but the changeset also includes substantial flashattn kernel refactoring in the example scripts that is equally significant to the ParallelLoopTransformer removal, so the title is only partially representative of the main changes.




📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fd59247 and 8c7d301.

📒 Files selected for processing (1)
  • src/op/fill.cc
💤 Files with no reviewable changes (1)
  • src/op/fill.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)





coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)

418-418: Bug: sparse_kernel.kernel does not exist - will raise AttributeError.

The SparseFlashAttn class (lines 159-207) does not define a kernel attribute. This will fail at runtime when run_regression_perf is called.

Compare with example_tilelang_sparse_gqa_decode_varlen_mask.py (lines 434-444) which correctly creates the kernel directly using flashattn(...).

🐛 Suggested fix
-    kernel = sparse_kernel.kernel
+    kernel = flashattn(
+        batch,
+        heads,
+        heads_kv,
+        dim,
+        dim_v,
+        block_N=block_size,
+        block_H=sparse_kernel.block_H,
+        num_stages=2,
+        threads=128,
+    )
♻️ Duplicate comments (2)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)

311-319: Same issue as varlen_mask: assert_close doesn't actually assert.

See previous comment on the varlen_mask file. Consider making this function consistent across files by either renaming or adding actual assertion behavior.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

144-148: Same sentinel value concern as varlen_mask file.

The combine kernel uses lse_local_split != 0 to detect valid splits, with the same potential issues around uninitialized glse memory. See the comment on the varlen_mask file for details.

🧹 Nitpick comments (3)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)

303-311: Function name assert_close is misleading - it doesn't actually assert.

The function only prints a message when values don't match but doesn't raise an exception. Consider either:

  1. Renaming to check_close or print_close_check to reflect actual behavior
  2. Adding an actual assertion to match the name
🔧 Suggested fix to make it actually assert
 def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3):
     all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
     print(name + "  all_close={}".format(all_close))
     if not all_close:
         diff = (expect - actual).abs()
         print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
         max_indices = torch.nonzero(diff == diff.max().item())
         first_index = tuple(max_indices[0].tolist())
         print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
+        raise AssertionError(f"{name} tensors are not close: max diff = {diff.max().item()}")
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)

327-327: Unnecessary int() cast - math.ceil already returns an integer.

As flagged by static analysis (RUF046), math.ceil() returns an integer in Python 3, so the int() wrapper is redundant.

🔧 Suggested fix
-    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
+    max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (1)

166-167: Remove debug print(main) statement.

This debug print statement will output the kernel IR on every JIT compilation, which clutters output in production. Consider removing it or guarding it behind a debug/verbose flag.

🔧 Suggested fix
-    print(main)
     return main
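If the IR dump is still useful during development, another option is to guard it instead of deleting it; the environment variable name below is purely illustrative, not an existing TileLang setting:

    import os

    # Only dump the lowered kernel IR when explicitly requested,
    # e.g. immediately before the existing `return main`.
    if os.environ.get("TL_DEBUG_PRINT_IR", "0") == "1":
        print(main)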
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2d8d367 and fd59247.

📒 Files selected for processing (7)
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
  • examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
  • src/op/atomic_add.cc
  • src/op/copy.cc
  • src/transform/common/loop_parallel_transform_utils.h
  • src/transform/layout_inference.cc
💤 Files with no reviewable changes (2)
  • src/transform/common/loop_parallel_transform_utils.h
  • src/transform/layout_inference.cc
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.cc (4)
src/op/parallel.h (1)
  • ParallelOp (170-178)
tilelang/jit/adapter/utils.py (1)
  • is_cpu_target (106-107)
tilelang/transform/__init__.py (1)
  • VectorizeLoop (334-342)
src/transform/loop_vectorize.cc (4)
  • VectorizeLoop (396-406)
  • VectorizeLoop (396-396)
  • VectorizeLoop (408-418)
  • VectorizeLoop (408-409)
src/op/atomic_add.cc (2)
tilelang/language/tir/op.py (1)
  • ret (1878-1891)
src/transform/loop_partition.cc (2)
  • PartitionLoop (64-167)
  • PartitionLoop (64-65)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (4)
  • flashattn (19-156)
  • main (38-154)
  • main (322-371)
  • SparseFlashAttn (159-207)
tilelang/language/loop.py (2)
  • Pipelined (97-134)
  • Parallel (13-72)
🪛 Ruff (0.14.11)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py

327-327: Value being cast to int is already an integer

Remove unnecessary int call

(RUF046)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (13)
src/op/atomic_add.cc (1)

785-790: LGTM - Clean removal of the intermediate transformation step.

The refactoring correctly replaces transformed_loop with fused_loop at both AtomicAddInferLayout and PartitionLoop call sites, aligning with the goal of phasing out ParallelLoopTransformer. The substitution logic has been relocated to layout_inference.cc where LayoutInferencer::Substitute now performs the transformation after thread-binding analysis, and ParallelLoopTransformer has been fully removed from the codebase.

src/op/copy.cc (1)

719-727: Clean removal of intermediate transformation step is complete and consistent across the codebase.

The refactor correctly eliminates the ParallelLoopTransformer::Substitute step by passing fused_loop directly to both ParallelOp and VectorizeLoop. No references to ParallelLoopTransformer remain in the codebase, and the pattern is consistent with similar lowering paths in fill.cc and layout_inference.cc, where ParallelOp receives a For object directly.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (4)

13-18: LGTM! Clean decorator configuration.

The @tilelang.jit decorator, with out_idx=[-1] marking the last parameter as the output and TL_ENABLE_FAST_MATH enabled for performance, is a clean and explicit configuration.


186-199: LGTM! Clean kernel instantiation and invocation pattern.

The new pattern of constructing the kernel with compile-time parameters via flashattn(...) and then invoking with runtime tensors is cleaner and more explicit than the previous approach.


384-386: LGTM! Deterministic seeding for reproducibility.

Setting both torch.manual_seed(42) and torch.cuda.manual_seed_all(42) ensures reproducible test results across runs.


128-132: The != 0 check is a valid sentinel pattern in this code. Since glse is always written at lines 108-110 (even when has_valid_block is false), and invalid splits retain their initialized logsum value of 0 while valid splits compute a non-zero log2 value, this approach is sound. No issue to address.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)

13-18: LGTM! Consistent decorator configuration across files.

The @tilelang.jit configuration matches the varlen_mask variant, maintaining consistency across the codebase.


74-78: LGTM! Block scheduling logic is correct.

The scheduling logic correctly distributes blocks across splits using floordiv and floormod operations.
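For intuition, a plain-Python sketch of the floordiv/floormod split-scheduling pattern; this is one common way such a partition is expressed, not a quote of the kernel code, which works with TileLang expressions on symbolic extents:

    def blocks_for_split(num_blocks: int, num_split: int, split_id: int) -> range:
        # Each split gets floordiv(num_blocks, num_split) blocks, and the first
        # floormod(num_blocks, num_split) splits receive one extra block.
        base = num_blocks // num_split   # floordiv
        extra = num_blocks % num_split   # floormod
        count = base + (1 if split_id < extra else 0)
        start = split_id * base + min(split_id, extra)
        return range(start, start + count)

    # Example: 10 blocks over 3 splits -> sizes 4, 3, 3
    assert [len(blocks_for_split(10, 3, s)) for s in range(3)] == [4, 3, 3]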

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (5)

15-21: LGTM! Extended JIT signature for paged KV cache.

The flashattn function signature correctly includes additional parameters for paged attention: page_block_size and num_pages.


39-40: LGTM! Proper constraint validation for paged cache.

The assertion ensures block_N divides evenly into page_block_size, which is required for correct block-to-page mapping via block_ratio.
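A small sketch of the constraint and the derived ratio; the values are illustrative and the variable names follow the review text rather than the exact source:

    # page_block_size must be an integer multiple of block_N, otherwise a block_N-sized
    # tile could straddle two physical KV-cache pages and the block-to-page mapping breaks.
    page_block_size = 64   # illustrative value
    block_N = 32           # illustrative value
    assert page_block_size % block_N == 0, "block_N must evenly divide page_block_size"
    block_ratio = page_block_size // block_N  # block_N tiles per KV-cache page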


170-222: LGTM! Well-structured SparseFlashAttn for paged KV cache.

The class correctly stores paged cache parameters (page_block_size, block_N, num_pages) and passes them to flashattn in the forward method.


503-505: LGTM! Reproducibility seeds added.

Manual seeds for both PyTorch and CUDA RNG ensure reproducible benchmark results.


610-625: LGTM! Correct kernel instantiation pattern in regression test.

Unlike example_tilelang_sparse_gqa_decode_varlen_indice.py, this file correctly creates the kernel directly using flashattn(...) rather than accessing a non-existent class attribute.


* Cleaned up the fill.cc file by removing the unused import of loop_parallel_transform_utils.h, improving code clarity and maintainability.
LeiWang1999 merged commit f035315 into tile-ai:main on Jan 15, 2026
6 checks passed
