
[Refactor] Reorganize ParallelOp code structure and move ProveFragmentContains to layout utils#1779

Merged
LeiWang1999 merged 5 commits into tile-ai:main from LeiWang1999:layout_0203 on Feb 6, 2026

Conversation

@LeiWang1999 (Member) commented Feb 3, 2026

Summary

  • Move the ProveFragmentContains function from parallel.cc to layout/utils.cc for better code organization
  • Move the IfBufferRemapLoopGenerator helper class into an anonymous namespace
  • Extract the buffer_is_completed_replicated lambda into a private method, IsBufferCompletelyReplicated
  • Move buffer access info collection (cross-thread access, store buffers) into the constructor as member variables
  • Attempt a DeReplicate optimization before layout validation to reduce replication when possible
  • Refactor ValidateCandidateAgainstFragments to support both a bool-returning and an exception-throwing mode

Changes

  • src/layout/utils.h: Add ProveFragmentContains function declaration
  • src/layout/utils.cc: Add ProveFragmentContains function implementation
  • src/op/parallel.h:
    • Remove ProveFragmentContains declaration (moved to utils)
    • Add IsBufferCompletelyReplicated private method
    • Add member variables for buffer access info
    • Update ValidateCandidateAgainstFragments signature with optional parameters
  • src/op/parallel.cc:
    • Remove ProveFragmentContains implementation
    • Move IfBufferRemapLoopGenerator to anonymous namespace
    • Add IsBufferCompletelyReplicated method implementation
    • Collect buffer access info in constructor instead of InferLayout
    • Add DeReplicate attempt before validation
    • Simplify validation code by reusing ValidateCandidateAgainstFragments

Summary by CodeRabbit

  • Refactor

    • Overhauled layout inference and fragment validation to better track thread bindings, replication, and remapping steps.
  • Improvements

    • Enhanced detection and handling of fully-replicated buffers and cross-thread buffer interactions.
    • Stronger replication guards and expanded debug logging for layout decisions.
  • New Features

    • Added a validation utility to prove fragment containment and integrated it into layout validation paths and APIs.
  • Tests

    • Tests updated to disable caching for kernel generation, print generated kernels for inspection, and gate a CUDA-specific test.

…d Validation

- Added the ProveFragmentContains function to check whether the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment (a hedged declaration sketch follows this list).
- This function ensures valid access when transitioning from a smaller to a larger fragment layout.
- Updated layout.cc and utils.cc to incorporate this new functionality, enhancing the layout validation process.
- Removed the previous implementation of ProveFragmentContains from parallel.cc to streamline the codebase.
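
A hedged sketch of what the new declaration in src/layout/utils.h could look like; the parameter names and the default below are assumptions reconstructed from the call shape shown later in this page's walkthrough diagram (ProveFragmentContains(small_frag, large_frag, small_idx, large_idx, Analyzer, check_forward)), not the verbatim signature from the diff:

// Hypothetical declaration sketch, assuming TVM/TileLang headers.
#include <tvm/arith/analyzer.h>
#include "layout.h"  // Fragment (included by utils.cc per the walkthrough)

namespace tvm {
namespace tl {

// Returns true when every thread that accesses an element of `small`
// (addressed by `small_indices`) also accesses the corresponding element
// of `large` (addressed by `large_indices`), i.e. the smaller fragment's
// thread mapping is contained in the larger one's. `check_forward`
// optionally also verifies that the forward index mapping agrees.
bool ProveFragmentContains(const Fragment &small, const Fragment &large,
                           const Array<PrimExpr> &small_indices,
                           const Array<PrimExpr> &large_indices,
                           arith::Analyzer &analyzer,
                           bool check_forward = false);

}  // namespace tl
}  // namespace tvm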
github-actions bot commented Feb 3, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Feb 3, 2026

📝 Walkthrough

Moves fragment-containment verification into layout utilities (adds ProveFragmentContains), binds thread ranges onto de-replicated and fully-replicated Fragments, removes local ProveFragmentContains from ParallelOp, and refactors ParallelOpNode to use IsBufferCompletelyReplicated, add cross-thread/store tracking, and update validation APIs and inference flow.

Changes

  • Layout thread binding (src/layout/layout.cc): Chain BindThreadRange(Range(0, ThreadExtent())) onto the fragments returned by FragmentNode::DeReplicate and Fragment::FullyReplicated (a sketch follows this list).
  • Layout utilities, new API (src/layout/utils.h, src/layout/utils.cc): Add the ProveFragmentContains(...) declaration and implementation in tvm::tl (includes layout.h); the function verifies whether a smaller fragment's thread mapping is contained in a larger fragment's, with an optional forward-index check.
  • Parallel op refactor & validation (src/op/parallel.h, src/op/parallel.cc): Remove the local ProveFragmentContains; add IsBufferCompletelyReplicated(...); track cross-thread/store state (has_cross_thread_access_, store_shared_global_buffers_, store_fragment_buffers_); change the ValidateCandidateAgainstFragments and ChooseBestCandidate signatures; replace containment checks with replication-based validation; and adjust the layout inference flow and replication guard construction.
  • Tests (testing/python/layout/test_tilelang_annotate_loop_layout.py): Add a CUDA compute-capability gate to one test, remove a minor TODO and stray whitespace, and tighten test formatting.
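
A hedged one-line sketch of the layout.cc change from the first item above; `result` is an assumed local variable name, and only the chained call itself comes from the summary:

// Sketch: fragments returned by FragmentNode::DeReplicate (and likewise
// Fragment::FullyReplicated) now carry a bound thread range covering the
// full extent [0, ThreadExtent()).
return result->BindThreadRange(Range(0, ThreadExtent()));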

Sequence Diagram(s)

sequenceDiagram
  participant Parallel as ParallelOpNode
  participant Layout as LayoutUtils::ProveFragmentContains
  participant Analyzer as Analyzer
  participant Buffer as Buffer/Fragment

  Parallel->>Buffer: collect fragments, accesses, indices
  Parallel->>Parallel: IsBufferCompletelyReplicated(buffer, layout_map)
  alt buffer fully replicated
    Parallel-->>Parallel: skip containment proof, bind thread range
  else not fully replicated
    Parallel->>Layout: ProveFragmentContains(small_frag, large_frag, small_idx, large_idx, Analyzer, check_forward)
    Layout-->>Parallel: true / false
    Parallel->>Parallel: ValidateCandidateAgainstFragments(candidate, T, throw_on_error, check_forward_index, source_buffer)
    Parallel->>Parallel: build replication guards, bind thread ranges
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes


Poem

🐇 I hopped through fragments, checked each thread,
Bound ranges where replication led,
Proved small inside large with a twitch,
Guarded loops without a glitch,
A rabbit cheer — layout well-fed!

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 16.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly and concisely summarizes the main refactoring (reorganizing ParallelOp code structure and moving ProveFragmentContains to layout utils) and aligns with the PR objectives.



coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/op/parallel.cc`:
- Around line 476-487: There is a stray semicolon after the call to
DeReplicate() and several info logs should be debug-only: remove the extraneous
semicolon following loop_layout_->DeReplicate() and ensure the result is
assigned to dereplicated_layout as intended, and replace LOG(INFO) calls inside
ComputeLoopLayoutFromBuffer with DLOG(INFO) for consistent debug logging; check
the code paths around DeReplicate(), loop_layout_, and
ValidateCandidateAgainstFragments to confirm behavior is unchanged after
removing the semicolon and update the LOG calls to DLOG to match the file's
debug logging convention.
🧹 Nitpick comments (3)
src/layout/utils.cc (1)

401-434: Analyzer binding has potential side effects on caller.

The analyzer.Bind(rep_small, ..., true) call modifies the analyzer's state. Since the analyzer is passed by reference, this binding persists after the function returns, which could affect subsequent analysis in the caller.

Consider whether this is intentional. If not, you may want to either:

  1. Document this side effect clearly
  2. Use a local analyzer copy if isolation is needed
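
A minimal sketch of option 2, assuming the variable names from the snippet under review (rep_small, plus a fragment named `small` exposing ReplicateExtent()); this is illustrative, not the PR's code:

// Bind into a fresh local analyzer so the caller's analyzer is left
// untouched once ProveFragmentContains returns.
arith::Analyzer local_analyzer;
local_analyzer.Bind(rep_small, Range(0, small->ReplicateExtent()),
                    /*allow_override=*/true);
// ... run the containment proof against `local_analyzer` instead of the
// caller-provided `analyzer` ...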
src/op/parallel.cc (2)

418-425: Consider using DLOG(INFO) for debug logging.

These LOG(INFO) statements will emit output in production builds. Other debug logging in this file uses DLOG(INFO) (e.g., lines 434, 437-438, 621-622). For consistency and to avoid noisy logs in production, consider changing these to DLOG(INFO).

♻️ Proposed fix
     if (read_source_buffer.defined() && allow_layout_propgate) {
       candidate_from_buffer =
           ComputeLoopLayoutFromBuffer(read_source_buffer, T);
-      LOG(INFO) << "read_source_buffer: " << read_source_buffer;
-      LOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
+      DLOG(INFO) << "read_source_buffer: " << read_source_buffer;
+      DLOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
     }

     // try to infer loop layout with two mechanisms and choose the best one
     {
       candidate_from_plan = ComputePlanCandidate(T);
-      LOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
+      DLOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
     }

629-656: Use DLOG(INFO) instead of LOG(INFO) for debug logging.

These logging statements should use DLOG(INFO) for consistency with the rest of the file and to avoid verbose output in production.

♻️ Proposed fix
   Var rep("_rep");
   auto rep_iter =
       IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
-  LOG(INFO) << "rep: " << rep;
-  LOG(INFO) << "src_layout: " << src_layout->DebugOutput();
-  LOG(INFO) << "src_layout->DeReplicate(): "
-            << src_layout->DeReplicate()->DebugOutput();
-  LOG(INFO) << "Create rep_iter: " << rep_iter
-            << " from rep_extent: " << src_layout->ReplicateExtent();
+  DLOG(INFO) << "rep: " << rep;
+  DLOG(INFO) << "src_layout: " << src_layout->DebugOutput();
+  DLOG(INFO) << "src_layout->DeReplicate(): "
+             << src_layout->DeReplicate()->DebugOutput();
+  DLOG(INFO) << "Create rep_iter: " << rep_iter
+             << " from rep_extent: " << src_layout->ReplicateExtent();
   PrimExpr loop_var_to_thread =
       src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
   loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
-  LOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
+  DLOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
   ...
   try {
     result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                  ->BindThreadRange(T.thread_bounds);
-    LOG(INFO) << "result: " << result;
-    LOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
+    DLOG(INFO) << "result: " << result;
+    DLOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
   } catch (const tvm::runtime::Error &err) {

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/op/parallel.cc`:
- Around line 208-222: The ParallelOpNode::IsBufferCompletelyReplicated function
unsafely calls layout_map[buffer].as<Fragment>().value() and inconsistently
treats index errors (returns false for non-IntImm but LOG(FATAL) for non-zero
IntImm); fix by first checking the as<Fragment>() cast result (e.g., auto
frag_opt = layout_map[buffer].as<Fragment>(); if (!frag_opt) return false)
before accessing value, and make index handling consistent by returning false
(not LOG(FATAL)) when any index is non-IntImm or an IntImm with value != 0; use
GetAccessInfo(buffer).indices to perform these checks and then call
frag_opt->IsCompletedReplicated() if all checks pass.
🧹 Nitpick comments (3)
src/op/parallel.cc (3)

120-135: Potential duplicate buffer entries in tracking vectors.

The PostOrderVisit may encounter multiple stores to the same buffer, adding duplicates to store_shared_global_buffers_ and store_fragment_buffers_. If downstream code iterates these vectors assuming unique entries, this could cause redundant processing or incorrect behavior.

Consider using a set or checking for existence before insertion:

♻️ Proposed fix to avoid duplicates
   PostOrderVisit(root_, [&](const ObjectRef &obj) {
     if (const auto *store = obj.as<BufferStoreNode>()) {
       auto buffer = store->buffer;
       if (IsSharedBuffer(buffer) || IsGlobalBuffer(buffer)) {
         has_cross_thread_access_ = true;
-        store_shared_global_buffers_.emplace_back(buffer);
+        if (std::find(store_shared_global_buffers_.begin(),
+                      store_shared_global_buffers_.end(), buffer) ==
+            store_shared_global_buffers_.end()) {
+          store_shared_global_buffers_.emplace_back(buffer);
+        }
       } else if (IsFragmentBuffer(buffer)) {
-        store_fragment_buffers_.emplace_back(buffer);
+        if (std::find(store_fragment_buffers_.begin(),
+                      store_fragment_buffers_.end(), buffer) ==
+            store_fragment_buffers_.end()) {
+          store_fragment_buffers_.emplace_back(buffer);
+        }
       }
     } else if (const auto *load = obj.as<BufferLoadNode>()) {

477-484: Consider skipping redundant re-validation after successful DeReplicate.

When ValidateCandidateAgainstFragments succeeds for dereplicated_layout (line 477-480), loop_layout_ is updated to dereplicated_layout. The subsequent call on line 482-484 then validates the same layout again with throw_on_error=true, which is redundant since it was just validated.

♻️ Proposed optimization
   Fragment dereplicated_layout = loop_layout_->DeReplicate();
   if (ValidateCandidateAgainstFragments(
           dereplicated_layout, T, /*throw_on_error=*/false,
           /*check_forward_index=*/false, source_buffer)) {
     loop_layout_ = dereplicated_layout;
+  } else {
+    // DeReplicate failed validation; validate original with throwing enabled
+    ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
+                                      /*check_forward_index=*/false,
+                                      source_buffer);
   }
-  ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
-                                    /*check_forward_index=*/false,
-                                    source_buffer);

626-630: Reconsider the leading underscore in the variable name.

The Var was renamed from rep to _rep, but it is actively used on line 630. Note that "_rep" here is the TIR variable's name string rather than a C++ identifier, so C++ reserved-identifier rules do not apply; still, a leading underscore conventionally signals an unused or internal variable, which is misleading for a variable that is in active use.

♻️ Suggested alternatives
-  Var rep("_rep");
+  Var rep("rep");

Or if distinguishing from other rep variables is needed:

-  Var rep("_rep");
+  Var rep("loop_rep");

Comment on lines +208 to +222
bool ParallelOpNode::IsBufferCompletelyReplicated(
    const Buffer &buffer, const LayoutMap &layout_map) const {
  if (!IsFragmentBuffer(buffer))
    return false;
  auto frag = layout_map[buffer].as<Fragment>().value();
  // buffer indices should be IntImm
  for (const auto &index : GetAccessInfo(buffer).indices) {
    if (!index.as<IntImmNode>()) {
      return false;
    } else if (index.as<IntImmNode>()->value != 0) {
      LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
    }
  }
  return frag->IsCompletedReplicated();
}

⚠️ Potential issue | 🟡 Minor

Unsafe .value() access and inconsistent error handling.

Two concerns:

  1. Line 212: T.layout_map[buffer].as<Fragment>().value() calls .value() without checking if the cast succeeded. If the layout is not a Fragment, this will crash.

  2. Lines 217-218: The method returns false for non-IntImm indices but calls LOG(FATAL) for non-zero IntImm values. This inconsistency is confusing—consider either returning false for both or documenting why non-zero indices are a fatal error.

🛡️ Proposed fix for safe access
 bool ParallelOpNode::IsBufferCompletelyReplicated(
     const Buffer &buffer, const LayoutMap &layout_map) const {
   if (!IsFragmentBuffer(buffer))
     return false;
-  auto frag = layout_map[buffer].as<Fragment>().value();
+  auto frag_opt = layout_map[buffer].as<Fragment>();
+  if (!frag_opt.has_value())
+    return false;
+  auto frag = frag_opt.value();
   // buffer indices should be IntImm
   for (const auto &index : GetAccessInfo(buffer).indices) {
     if (!index.as<IntImmNode>()) {
       return false;
     } else if (index.as<IntImmNode>()->value != 0) {
-      LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+      return false;  // Non-zero index means not completely replicated
     }
   }
   return frag->IsCompletedReplicated();
 }

- Removed the initial DeReplicate attempt from InferLayout to streamline layout inference.
- Added DeReplicate logic to ComputeLoopLayoutFromBuffer to reduce replication when validating layout candidates (see the sketch after this list).
- Updated test cases to disable caching and ensure proper functionality of loop layout kernels.
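
A rough sketch of where the DeReplicate attempt described above could sit at the end of ComputeLoopLayoutFromBuffer; the Fragment construction mirrors the snippet reviewed earlier, but whether validation happens inside this helper, and the `buffer` argument, are assumptions:

// Build the candidate as before, then prefer its de-replicated form when
// that form still validates against all known fragment layouts.
Fragment candidate = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                         ->BindThreadRange(T.thread_bounds);
Fragment dereplicated = candidate->DeReplicate();
if (ValidateCandidateAgainstFragments(dereplicated, T,
                                      /*throw_on_error=*/false,
                                      /*check_forward_index=*/false, buffer)) {
  candidate = dereplicated;
}
return candidate;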
coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@testing/python/layout/test_tilelang_annotate_loop_layout.py`:
- Around line 110-112: The test runner was replaced with a direct call to
test_loop_layout_identity(), skipping other tests; restore the standard runner
by re-enabling tilelang.testing.main() (or calling the test framework's main
entry) in the if __name__ == "__main__": block and remove the explicit
test_loop_layout_identity() invocation so that test_loop_layout_fragment_vec4,
test_copy_loop_layout_annotated_replicate_vec4, and
test_annotate_replicate_loop_layout_vec4 are executed via the test runner.
- Around line 43-47: Remove the debug print and either remove or document the
cache disable: delete the print(code) statement that prints kernel source (from
kernel.get_kernel_source()) and if tilelang.disable_cache() was only used for
temporary debugging, remove that call; if disabling cache is required for this
test, add a brief inline comment next to tilelang.disable_cache() explaining why
caching must be disabled for this test (e.g., to ensure fresh kernel
compilation). Ensure the assertion comparing the generated code remains
unchanged.

@LeiWang1999 (Member, Author) commented:

@regression-perf

github-actions bot commented Feb 4, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21658459663

Results

File Original Latency Current Latency Speedup
example_per_token_cast_to_fp8 0.00741028 0.00981551 0.754956
example_topk 0.010849 0.013473 0.80524
example_warp_specialize_gemm_softpipe_stage2 0.038209 0.039136 0.976313
example_gqa_decode 0.047585 0.048256 0.986095
example_tilelang_gemm_fp8_2xAcc 0.185923 0.188274 0.98751
example_convolution_autotune 0.991756 0.997459 0.994282
sparse_mla_fwd 0.130814 0.131546 0.994429
example_tilelang_sparse_gqa_decode_varlen_indice 0.0170542 0.0171371 0.995158
example_gemm 0.022656 0.022752 0.995781
example_mha_sink_bwd_bhsd_sliding_window 0.0447301 0.0448682 0.996921
example_convolution 1.33469 1.33731 0.998042
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0145713 0.0145972 0.998227
example_gemv 0.28471 0.285126 0.998543
example_gqa_bwd_wgmma_pipelined 0.0695477 0.0696481 0.998559
example_gemm_autotune 0.022304 0.022336 0.998567
example_dequant_gemm_bf16_mxfp4_hopper 0.508586 0.50907 0.999049
tilelang_example_sparse_tensorcore 0.0150216 0.0150305 0.999409
example_group_per_split_token_cast_to_fp8 0.0103507 0.010356 0.999489
example_linear_attn_fwd 0.0368838 0.036899 0.999588
example_tilelang_gemm_fp8 0.322293 0.32239 0.999698
example_dequant_gemv_fp16xint4 0.028414 0.0284205 0.999774
example_tilelang_sparse_gqa_decode_varlen_mask 0.0234106 0.0234141 0.999851
example_mha_bwd_bhsd 0.040665 0.0406691 0.999899
example_vertical_slash_sparse_attn 0.23724 0.237261 0.99991
example_gqa_bwd_tma_reduce_varlen 0.0521791 0.0521837 0.99991
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0144449 0.0144459 0.999929
example_tilelang_gemm_fp8_intrinsic 0.933421 0.933478 0.999939
block_sparse_attn_tilelang 0.0102525 0.0102531 0.999939
example_tilelang_gemm_splitk 1.4228 1.42284 0.999972
example_gqa_sink_bwd_bhsd_sliding_window 0.0255522 0.0255505 1.00007
example_linear_attn_bwd 0.153165 0.153153 1.00008
example_mha_bwd_bshd 0.0412572 0.041254 1.00008
example_tilelang_nsa_fwd 0.00685982 0.00685926 1.00008
example_gemm_schedule 0.0325899 0.0325864 1.00011
example_gqa_sink_bwd_bhsd 0.0416456 0.0416397 1.00014
topk_selector 0.0535432 0.0535342 1.00017
example_elementwise_add 0.295893 0.29584 1.00018
example_mha_bwd_bshd_wgmma_pipelined 0.0256546 0.0256494 1.0002
example_mha_sink_bwd_bhsd 0.0623852 0.0623722 1.00021
example_blocksparse_gemm 0.022672 0.022667 1.00022
sparse_mla_bwd 0.383223 0.383121 1.00027
example_tilelang_gemm_splitk_vectorize_atomicadd 1.42263 1.42198 1.00046
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154901 0.015483 1.00046
example_gqa_bwd 0.049795 0.0497697 1.00051
example_mla_decode 0.4611 0.460777 1.0007
example_mha_fwd_varlen 0.0454591 0.045415 1.00097
example_tilelang_nsa_decode 0.00734988 0.00734142 1.00115
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.45833 3.45336 1.00144
example_dynamic 0.657485 0.656524 1.00146
example_mha_sink_fwd_bhsd_sliding_window 0.0156881 0.0156636 1.00156
example_dequant_gemm_w4a8 5.40415 5.39481 1.00173
example_gemm_intrinsics 0.035073 0.035009 1.00183
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0154353 0.015405 1.00197
example_tilelang_block_sparse_attn 0.0101789 0.0101556 1.0023
fp8_lighting_indexer 0.0357559 0.0355217 1.00659
example_dequant_gemm_bf16_fp4_hopper 0.579304 0.575145 1.00723
example_mha_sink_fwd_bhsd 0.0159033 0.0157592 1.00915
sparse_mla_fwd_pipelined 0.0956494 0.0946471 1.01059
example_dequant_gemm_fp4_hopper 1.07228 1.06068 1.01094
example_warp_specialize_gemm_copy_1_gemm_0 0.03824 0.037664 1.01529
example_warp_specialize_gemm_barrierpipe_stage2 0.038784 0.038049 1.01932
example_warp_specialize_gemm_copy_0_gemm_1 0.038496 0.037697 1.0212
example_mha_inference 0.081058 0.07904 1.02553

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 (Member, Author) commented:

@regression-perf

github-actions bot commented Feb 5, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21699570055

Results

File Original Latency Current Latency Speedup
example_per_token_cast_to_fp8 0.00739141 0.00977391 0.756239
example_topk 0.010753 0.013216 0.813635
example_warp_specialize_gemm_softpipe_stage2 0.037249 0.039232 0.949455
example_warp_specialize_gemm_copy_0_gemm_1 0.037953 0.038945 0.974528
example_warp_specialize_gemm_copy_1_gemm_0 0.037568 0.038273 0.98158
example_mha_inference 0.077026 0.077506 0.993807
example_tilelang_sparse_gqa_decode_varlen_indice 0.0168858 0.0169779 0.994576
example_dequant_gemm_bf16_mxfp4_hopper 0.563685 0.565005 0.997664
sparse_mla_fwd 0.130252 0.130453 0.998454
example_gqa_decode 0.048224 0.048289 0.998654
example_gqa_sink_bwd_bhsd_sliding_window 0.0251311 0.025156 0.999008
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231264 0.0231432 0.999274
example_blocksparse_gemm 0.0224166 0.0224329 0.999275
example_mha_bwd_bshd_wgmma_pipelined 0.0255022 0.0255168 0.999425
example_mla_decode 0.449034 0.449289 0.999432
example_gemm_schedule 0.0322449 0.0322632 0.999433
example_gqa_bwd_tma_reduce_varlen 0.051314 0.0513394 0.999506
example_tilelang_block_sparse_attn 0.0100625 0.0100668 0.999577
tilelang_example_sparse_tensorcore 0.0149016 0.0149076 0.9996
example_gqa_bwd 0.0490208 0.0490397 0.999613
example_tilelang_gemm_fp8_intrinsic 0.910589 0.910934 0.999621
example_tilelang_gemm_splitk 1.40212 1.40265 0.999624
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40142 1.40177 0.999748
example_group_per_split_token_cast_to_fp8 0.0103227 0.0103251 0.999768
example_convolution_autotune 0.991602 0.991792 0.999808
example_linear_attn_bwd 0.15244 0.152465 0.99983
example_dequant_gemm_w4a8 5.30119 5.30207 0.999834
example_linear_attn_fwd 0.0365489 0.0365535 0.999874
example_elementwise_add 0.294005 0.294026 0.99993
example_dequant_gemv_fp16xint4 0.0283649 0.0283662 0.999955
example_gemm_intrinsics 0.034657 0.034657 1
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0142967 0.0142961 1.00004
example_gqa_sink_bwd_bhsd 0.0408335 0.0408259 1.00018
example_dynamic 0.651626 0.651501 1.00019
example_vertical_slash_sparse_attn 0.23166 0.231595 1.00028
example_gqa_bwd_wgmma_pipelined 0.0688768 0.0688565 1.0003
example_convolution 1.30866 1.30825 1.00032
example_gemv 0.281493 0.281355 1.00049
block_sparse_attn_tilelang 0.0101569 0.0101519 1.0005
example_dequant_gemm_bf16_fp4_hopper 0.626412 0.625999 1.00066
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144518 0.014435 1.00117
example_tilelang_nsa_fwd 0.00681987 0.00681142 1.00124
example_tilelang_nsa_decode 0.00730589 0.00729623 1.00132
example_mha_bwd_bhsd 0.0400585 0.0400044 1.00135
example_mha_sink_bwd_bhsd 0.0616336 0.0615437 1.00146
topk_selector 0.0531949 0.0530925 1.00193
example_tilelang_gemm_fp8 0.318134 0.317412 1.00227
example_mha_bwd_bshd 0.0407504 0.0406231 1.00313
example_gemm_autotune 0.02224 0.022145 1.00429
example_tilelang_gemm_fp8_2xAcc 0.182484 0.181683 1.0044
sparse_mla_bwd 0.379074 0.377197 1.00498
example_mha_sink_fwd_bhsd_sliding_window 0.0155927 0.015511 1.00527
example_mha_fwd_varlen 0.0452419 0.0449749 1.00594
fp8_lighting_indexer 0.0353867 0.0351756 1.006
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154363 0.0153418 1.00616
example_mha_sink_bwd_bhsd_sliding_window 0.0445644 0.0442906 1.00618
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0153523 0.0152319 1.00791
example_mha_sink_fwd_bhsd 0.0157814 0.0156284 1.00979
sparse_mla_fwd_pipelined 0.0950355 0.0940212 1.01079
example_gemm 0.022592 0.022336 1.01146
example_dequant_gemm_fp4_hopper 1.04914 1.02952 1.01906
example_warp_specialize_gemm_barrierpipe_stage2 0.04032 0.039072 1.03194
example_dequant_groupedgemm_bf16_mxfp4_hopper 4.18087 3.8724 1.07966

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

- Removed caching disablement and print statements from the loop layout identity test for cleaner output.
- Updated the main execution block to directly call the testing framework, enhancing test execution flow.
LeiWang1999 merged commit b01cb93 into tile-ai:main on Feb 6, 2026
6 checks passed