Skip to content

[Enhancement] Improve equality checks in layout nodes and fragment validation#1573

Merged
LeiWang1999 merged 2 commits intomainfrom
fix/layout-equality-check
Dec 30, 2025
Merged

[Enhancement] Improve equality checks in layout nodes and fragment validation#1573
LeiWang1999 merged 2 commits intomainfrom
fix/layout-equality-check

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Dec 30, 2025

Summary

This PR improves the correctness of layout and fragment equality checks, and adds validation for physical index compatibility in ProveFragmentContains.

Problem

  1. LayoutNode::IsEqual and FragmentNode::IsEqual had incorrect comparison logic: The original implementation used StructuralEqual directly on forward_index_, which could incorrectly compare InputPlaceholder variables. For example, two layouts with different mappings like i * 16 + j and j * 16 + i might be incorrectly considered equal if their InputPlaceholder indices happened to match structurally.

  2. ProveFragmentContains only checked thread containment, not physical index compatibility: When validating a loop layout candidate against a buffer fragment, the function only verified that threads accessing small fragment elements are a subset of threads accessing large fragment elements. However, it didn't verify that the physical indices match, which could lead to incorrect code generation.

    Example of incorrect code generation:

    Loop layout:   (i, j) -> physical = [0],  thread = i*4+j
    Buffer layout: (i, j) -> physical = [j],  thread = i*4+rep
    

    Without physical index validation, code generation would produce:

    // Incorrect: col_sum[(threadIdx.x) & 3] 
    // Should be: col_sum[j] where j is correctly derived
    sums[(threadIdx.x) + 128] = col_sum[(threadIdx.x) & 3];

Solution

  1. Fix IsEqual by using Forward with common variables:

    // Create common variables for comparison
    Array<PrimExpr> common_vars;
    for (size_t i = 0; i < this->InputDim(); i++) {
      common_vars.push_back(Var("_cmp_v" + std::to_string(i)));
    }
    
    // Compare Forward results with StructuralEqual
    auto this_forward = this->Forward(common_vars);
    auto other_forward = other->Forward(common_vars);
    if (!StructuralEqual()(this_forward, other_forward)) {
      return false;
    }
  2. Add check_forward_index parameter to ProveFragmentContains:

    bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                               Array<PrimExpr> small_frag_indices,
                               Array<PrimExpr> large_frag_indices,
                               arith::Analyzer &analyzer_,
                               bool check_forward_index = false);

    When check_forward_index=true, validates physical index equality:

    auto small_physical = small_frag->Forward(small_frag_indices);
    auto large_physical = large_frag->Forward(large_frag_indices);
    // Dimension mismatch or value mismatch -> return false

Changes

  • src/op/parallel.h: Add check_forward_index parameter to ProveFragmentContains
  • src/op/parallel.cc: Implement physical index validation; update call sites in ValidateCandidateAgainstFragments and InferLayout
  • src/layout/layout.cc: Fix LayoutNode::IsEqual and FragmentNode::IsEqual to use Forward with common variables
  • testing/python/layout/test_tilelang_layout_equal.py: Add comprehensive tests for layout/fragment equality
  • Move test_tilelang_issue_layout.py to testing/python/layout/test_tilelang_layout_inference.py

Test plan

  • All 24 new layout equality tests pass
  • Existing layout inference test passes
  • Manual verification that incorrect code generation case is now caught

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Improved layout equality validation to be more robust against structural differences in internal representations.
    • Enhanced fragment containment verification with forward index validation to ensure correctness during layout operations.
  • Tests

    • Added comprehensive test suite for layout and fragment equality semantics, covering edge cases and complex scenarios.

✏️ Tip: You can customize this high-level summary in your review settings.

…lidation

* Enhanced the IsEqual method in LayoutNode and FragmentNode to include detailed comparisons of forward mappings, ensuring accurate equality checks.
* Introduced a new parameter in ProveFragmentContains to validate physical indices when checking fragment containment, improving correctness in layout validation.
* Removed obsolete test file related to layout inference.

This update aims to strengthen the integrity of layout comparisons and fragment validations in the system.
@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 30, 2025

📝 Walkthrough

Walkthrough

This change enhances Layout and Fragment equality comparison by comparing forward-mapped synthetic variables instead of raw indices, and introduces forward-index validation in fragment containment checks via a new parameter in ProveFragmentContains.

Changes

Cohort / File(s) Change Summary
Layout Equality Enhancement
src/layout/layout.cc
Modified LayoutNode::IsEqual and FragmentNode::IsEqual to construct synthetic common_vars and compare Forward(common_vars) mappings instead of direct forward_index_ comparisons; added ForwardThread mapping comparison for fragments to improve robustness against AST structural differences.
Fragment Containment Validation
src/op/parallel.h, src/op/parallel.cc
Added optional check_forward_index parameter (default false) to ProveFragmentContains signature; implemented guarded validation block in implementation to compare forward and physical indices of small/large fragments when enabled; updated call site in ValidateCandidateAgainstFragments to pass check_forward_index=true.
Equality Test Coverage
testing/python/layout/test_tilelang_layout_equal.py
New comprehensive test module covering Layout/Fragment equality semantics: identical mappings, differing index orders, coefficients, shapes, variable equivalence, thread mappings, forward definitions, replicate factors, and edge cases across 1D, 2D, and 3D layouts.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Poem

🐰 Ears perk up with glee,
Forward maps now dance so free,
No more index confusion here,
Synthetic vars make truth clear,
Thread and layout now agree!

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main changes: improving equality checks in layout nodes (LayoutNode::IsEqual, FragmentNode::IsEqual) and fragment validation (ProveFragmentContains), which are the core objectives of this PR.
Docstring Coverage ✅ Passed Docstring coverage is 87.10% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@LeiWang1999
Copy link
Member Author

@codex review

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
testing/python/layout/test_tilelang_layout_equal.py (1)

8-8: Consider removing set_random_seed() for this test module.

This test file contains only deterministic equality comparisons with no random number generation. The set_random_seed() call is harmless but unnecessary, adding a slight inconsistency since the tests don't rely on randomness.

🔎 Proposed fix
-tilelang.testing.set_random_seed()
-
-
+
 class TestLayoutEqual:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b6a2513 and 2bc0016.

📒 Files selected for processing (5)
  • src/layout/layout.cc
  • src/op/parallel.cc
  • src/op/parallel.h
  • testing/python/layout/test_tilelang_layout_equal.py
  • testing/python/layout/test_tilelang_layout_inference.py
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/layout/test_tilelang_layout_equal.py (3)
src/layout/layout.cc (4)
  • Layout (60-72)
  • Layout (74-77)
  • Fragment (518-540)
  • Fragment (542-552)
tilelang/testing/__init__.py (1)
  • set_random_seed (32-37)
tilelang/layout/fragment.py (1)
  • replicate (138-152)
🪛 Ruff (0.14.10)
testing/python/layout/test_tilelang_layout_equal.py

86-86: Unused lambda argument: j

(ARG005)


87-87: Unused lambda argument: i

(ARG005)


98-98: Unused lambda argument: j

(ARG005)


99-99: Unused lambda argument: j

(ARG005)


104-104: Unused lambda argument: j

(ARG005)


105-105: Unused lambda argument: j

(ARG005)


164-164: Unused lambda argument: i

(ARG005)


164-164: Unused lambda argument: j

(ARG005)


165-165: Unused lambda argument: i

(ARG005)


165-165: Unused lambda argument: j

(ARG005)


172-172: Unused lambda argument: i

(ARG005)


172-172: Unused lambda argument: j

(ARG005)


173-173: Unused lambda argument: j

(ARG005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (8)
src/op/parallel.h (1)

27-31: LGTM!

The API extension is well-designed with a default parameter value of false, ensuring backward compatibility with existing call sites while enabling the new forward-index validation when explicitly requested.

src/layout/layout.cc (2)

695-718: LGTM!

The fix correctly addresses the equality comparison issue by constructing fresh common variables and comparing the actual forward mapping results rather than the raw AST structure. The early return pattern improves readability.


720-764: LGTM!

The implementation correctly handles the broadcast case and uses the fresh common variable approach for both forward index and forward thread comparisons. The addition of ForwardThread comparison ensures complete fragment equality validation.

src/op/parallel.cc (2)

44-70: LGTM!

The forward index validation logic is well-implemented. The dimension check followed by component-wise equality verification using analyzer_.Simplify is the correct approach for comparing symbolic expressions.


621-637: LGTM!

The call site correctly enables forward index validation during candidate validation. The comment clearly explains the rationale: ensuring physical indices match for correct code generation. Other call sites appropriately use the default false since they're checking containment relationships rather than strict validation.

testing/python/layout/test_tilelang_layout_equal.py (3)

11-67: LGTM!

Comprehensive test coverage for layout equality including shape mismatches, coefficient differences, variable name independence, 2D outputs, and the notable edge case of commutative expressions where the test correctly validates behavior without assuming AST-level equivalence.


69-125: LGTM!

Good coverage of fragment equality scenarios including thread mapping, forward index, replicate factors, and the forward_fn API. The unused lambda arguments flagged by static analysis are intentional—the lambdas must match the expected arity even when only a subset of parameters is used in the expression.


127-175: LGTM!

Excellent edge case coverage including single/multi-dimensional layouts, empty forward index, and the important distinction between constant and variable mappings using tvm.tir.const.

@LeiWang1999
Copy link
Member Author

@regression-perf

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +54 to +58
auto small_physical = small_frag->Forward(small_frag_indices);
auto large_physical = large_frag->Forward(large_frag_indices);

// Dimension mismatch means they are not equal.
if (small_physical.size() != large_physical.size()) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Skip size-mismatch rejects when forward_index is absent

When check_forward_index is enabled, any size mismatch between small_frag->Forward(...) and large_frag->Forward(...) is treated as incompatible. This causes plan candidates (which often have a non-empty forward_index, e.g., from PlanLoopPartition) to be rejected against buffer fragments that intentionally use an empty forward_index (see CompleteBufferFragment/ComputeLoopLayoutFromBuffer constructing Fragment(..., {})). The result is that ValidateCandidateAgainstFragments can mark the plan candidate invalid and ChooseBestCandidate falls back to the more replicated buffer-based candidate, introducing a performance regression even though no physical index compatibility can be checked. Consider only enforcing the forward-index equality when both fragments define a non-empty forward index (or treating empty as “don’t care”).

Useful? React with 👍 / 👎.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! However, this concern doesn't apply here.

When a Fragment is constructed with an empty forward_index (e.g., Fragment(..., {}, ...)), the constructor automatically infers and fills it via infer_fragment_index():

  // layout.cc:510-513
  if (forward_index.empty()) {
    forward_index = {
        infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
  }

So after construction, all Fragments have non-empty forward_index_. The "empty forward_index" scenario you described doesn't exist at runtime - it's only a construction-time shorthand that gets filled automatically.

@github-actions
Copy link

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/20599596999

Results

File Original Latency Current Latency Speedup
example_warp_specialize_gemm_barrierpipe_stage2 0.038177 0.040289 0.947579
example_gqa_decode 0.047489 0.048513 0.978892
example_tilelang_gemm_fp8 0.32055 0.323319 0.991434
example_mha_bwd_bshd 0.0407307 0.0409872 0.993743
example_gemm_autotune 0.022272 0.022369 0.995664
example_dynamic 0.65451 0.656335 0.997219
example_mha_fwd_varlen 0.0451077 0.0451983 0.997995
example_mha_bwd_bhsd 0.0399362 0.0400027 0.998336
example_gemm 0.022592 0.022624 0.998586
example_tilelang_gemm_fp8_intrinsic 0.466554 0.46678 0.999516
example_vertical_slash_sparse_attn 0.237613 0.237722 0.999542
example_gqa_bwd 0.049832 0.0498269 1.0001
example_tilelang_gemm_splitk 1.40852 1.40829 1.00016
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40737 1.40703 1.00024
example_mha_inference 0.072449 0.072418 1.00043
example_warp_specialize_gemm_copy_0_gemm_1 0.039745 0.039713 1.00081
example_gqa_bwd_tma_reduce_varlen 0.0636369 0.063576 1.00096
example_linear_attn_fwd 0.0365442 0.036491 1.00146
example_linear_attn_bwd 0.152156 0.151873 1.00187
example_gqa_bwd_wgmma_pipelined 0.0742521 0.0741092 1.00193
example_gemm_schedule 0.0326117 0.0325423 1.00213
tilelang_example_sparse_tensorcore 0.0150657 0.0150324 1.00222
example_warp_specialize_gemm_softpipe_stage2 0.039073 0.038977 1.00246
example_gemm_intrinsics 0.035137 0.035009 1.00366
block_sparse_attn_tilelang 0.0102827 0.0102443 1.00375
example_fusedmoe_tilelang 0.130872 0.130365 1.00389
example_elementwise_add 0.297524 0.295667 1.00628
example_tilelang_gemm_fp8_2xAcc 0.189429 0.188196 1.00655
example_dequant_gemv_fp16xint4 0.028699 0.0285056 1.00679
example_per_token_cast_to_fp8 0.00751 0.00745457 1.00744
example_convolution_autotune 1.00454 0.995354 1.00922
example_mha_bwd_bshd_wgmma_pipelined 0.0256704 0.0254349 1.00926
example_tilelang_nsa_fwd 0.00714488 0.00706112 1.01186
example_tilelang_nsa_decode 0.00682458 0.00674449 1.01187
example_gemv 0.279586 0.276177 1.01234
example_warp_specialize_gemm_copy_1_gemm_0 0.03904 0.038497 1.0141
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.54828 3.49757 1.0145
example_dequant_gemm_w4a8 5.5077 5.40044 1.01986
example_topk 0.011136 0.010912 1.02053
example_tilelang_block_sparse_attn 0.010394 0.0101717 1.02186
fp8_lighting_indexer 0.0367359 0.0359078 1.02306
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.015817 0.0154468 1.02397
example_tilelang_sparse_gqa_decode_varlen_indice 0.0176224 0.017208 1.02408
example_mha_sink_fwd_bhsd 0.0162244 0.0158356 1.02456
sparse_mla_fwd_pipelined 0.096605 0.094288 1.02457
example_mha_sink_fwd_bhsd_sliding_window 0.0160311 0.0156437 1.02476
example_mha_sink_bwd_bhsd_sliding_window 0.0455647 0.0444465 1.02516
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0157734 0.015379 1.02564
example_dequant_gemm_fp4_hopper 1.09429 1.06642 1.02614
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0147915 0.0144077 1.02664
example_group_per_split_token_cast_to_fp8 0.0105102 0.0102358 1.02682
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0149692 0.014559 1.02818
example_blocksparse_gemm 0.0232711 0.0226315 1.02826
sparse_mla_fwd 0.143527 0.13955 1.02849
example_tilelang_sparse_gqa_decode_varlen_mask 0.0240485 0.0233746 1.02883
example_dequant_gemm_bf16_mxfp4_hopper 0.528558 0.5123 1.03174
sparse_mla_bwd 0.394367 0.381911 1.03262
example_gqa_sink_bwd_bhsd_sliding_window 0.0264273 0.0255607 1.0339
topk_selector 0.0554232 0.0535515 1.03495
example_mha_sink_bwd_bhsd 0.0635433 0.06133 1.03609
example_gqa_sink_bwd_bhsd 0.0432016 0.041658 1.03705
example_convolution 1.38575 1.33395 1.03883
example_mla_decode 0.480748 0.461002 1.04283
example_dequant_gemm_bf16_fp4_hopper 0.598253 0.570029 1.04951

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 0fa16b4 into main Dec 30, 2025
7 checks passed
@LeiWang1999 LeiWang1999 deleted the fix/layout-equality-check branch December 31, 2025 03:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant