From 7d5069983615b62ae875e550d73b06b69b27469c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 26 Jan 2026 20:23:32 +0800 Subject: [PATCH 01/16] [BugFix] Fix boundary check in Parallel Op --- src/transform/loop_partition.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 35166e4b4..0e0b0d855 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -83,9 +83,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, Array loop_extents; auto inverse_info = loop_layout->InverseWithLevel(); auto inv_loop = inverse_info.first; - // Must check the guard if the layout can not be proved as bijective - bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective; auto indices = inv_loop->Forward(Array(vars.begin(), vars.end())); + // Normalize thread var once so we can reuse the same substitution later. Map thread_offset_map; bool has_thread_offset = false; @@ -94,15 +93,25 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, thread_offset_map.Set(thread_var, thread_var - range->min); has_thread_offset = true; } + + bool bounds_match = true; for (int i = 0; i < old_loop_depth; i++) { const ForNode *loop = body.as(); ICHECK(loop != nullptr) << "No extra statements are allowed between nested parallel loops."; + if (!analyzer->CanProve(loop->extent == loop_layout->InputShape()[i]) || + !analyzer->CanProve(loop->min == 0)) { + bounds_match = false; + } vmap.Set(loop->loop_var, indices[i]); loop_mins.push_back(loop->min); loop_extents.push_back(loop->extent); body = loop->body; } + + // Must check the guard if the layout can not be proved as bijective or bounds don't match + bool need_guard = (inverse_info.second != arith::IterMapLevel::Bijective) || !bounds_match; + // substitute and re-construct the serial loop body = Substitute(body, vmap); // Guard executes the recovered loop body only if each 
inverse-mapped iterator From cb3f86bc34df0c21389c4eccf81cbe35cd901761 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 26 Jan 2026 20:24:37 +0800 Subject: [PATCH 02/16] [Lint] --- src/transform/loop_partition.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 0e0b0d855..1c2c21c83 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -109,8 +109,10 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, body = loop->body; } - // Must check the guard if the layout can not be proved as bijective or bounds don't match - bool need_guard = (inverse_info.second != arith::IterMapLevel::Bijective) || !bounds_match; + // Must check the guard if the layout can not be proved as bijective or bounds + // don't match + bool need_guard = + (inverse_info.second != arith::IterMapLevel::Bijective) || !bounds_match; // substitute and re-construct the serial loop body = Substitute(body, vmap); From 0c90536b59e103081a36c271b1b4b968f67228c4 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 11:15:57 +0800 Subject: [PATCH 03/16] [Testing] Add testing script for loop layout boundary check --- ...elang_transform_loop_partition_boundary.py | 73 +++++++++++++++++++ .../test_tilelang_transform_lower_tile_op.py | 8 +- 2 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 testing/python/transform/test_tilelang_transform_loop_partition_boundary.py diff --git a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py new file mode 100644 index 000000000..d43444dfc --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py @@ -0,0 +1,73 @@ +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target + + 
+def _tilelang_transform_loop_partition_boundary(): + def before(): + @T.prim_func + def main( + S: T.Tensor((8), T.bfloat16), + D: T.Tensor((4, 64), T.bfloat16), + ): + with T.Kernel(1, threads=128): + S_shared = T.alloc_shared((8), T.bfloat16) + S_fragment = T.alloc_fragment((8), T.float32) + D_shared = T.alloc_shared((4, 64), T.bfloat16) + + T.copy(S, S_shared) + T.copy(S_shared, S_fragment) + for k in T.serial(64): + for i in T.Parallel(4): + D_shared[i, k] = S_fragment[i] + T.copy(D_shared, D) + return main + + def after(): + @T.prim_func + def main( + S: T.Tensor((8), T.bfloat16), + D: T.Tensor((4, 64), T.bfloat16), + ): + with T.Kernel(1, threads=128): + S_shared = T.alloc_shared((8), T.bfloat16) + S_fragment = T.alloc_fragment((8), T.float32) + D_shared = T.alloc_shared((4, 64), T.bfloat16) + + T.copy(S, S_shared) + T.copy(S_shared, S_fragment) + for k in T.serial(64): + for i in T.Parallel(8): + if i < 4: + D_shared[i, k] = S_fragment[i] + T.copy(D_shared, D) + return main + + return tvm.IRModule({"main": before()}), tvm.IRModule({"main": after()}) + + +def boundary_check(): + before, after = _tilelang_transform_loop_partition_boundary() + target = tvm.target.Target(determine_target("auto")) + with target: + with tvm.transform.PassContext(): + mod = tvm.tir.transform.BindTarget(target)(before) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) + mod = tvm.tir.transform.Simplify()(mod) + with tvm.transform.PassContext(): + ref_mod = tvm.tir.transform.BindTarget(target)(after) + ref_mod = tilelang.transform.LayoutInference()(ref_mod) + ref_mod = tilelang.transform.LowerTileOp()(ref_mod) + ref_mod = tvm.tir.transform.Simplify()(ref_mod) + assert mod["main"].script() == ref_mod["main"].script(), "mod and ref_mod are not structural equal" + + +def test_tilelang_transform_loop_partition_boundary(): + boundary_check() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index 16c7cb802..7245d2b6f 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -72,12 +72,14 @@ def main( return tvm.IRModule({"main": main}) + before_mod, after_mod = before(), after() with tvm.transform.PassContext(): - mod = tvm.tir.transform.BindTarget(auto_target)(before()) + mod = tvm.tir.transform.BindTarget(auto_target)(before_mod) mod = tl.transform.LowerTileOp()(mod) mod = tvm.tir.transform.Simplify()(mod) - ref_mod = tvm.tir.transform.BindTarget(auto_target)(after()) - ref_mod = tvm.tir.transform.Simplify()(ref_mod) + with tvm.transform.PassContext(): + ref_mod = tvm.tir.transform.BindTarget(auto_target)(after_mod) + ref_mod = tvm.tir.transform.Simplify()(ref_mod) # Note(tzj): The structures are equal except the argument in "T.reads" function. 
# The difference is just between the first index and the indices range, which is totally equivalent tvm.ir.structural_equal(mod, ref_mod) From f58db537c2cf5ad3c454fd0072d74797c1166f35 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 11:16:10 +0800 Subject: [PATCH 04/16] [Lint] --- .../test_tilelang_transform_loop_partition_boundary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py index d43444dfc..a40b52d04 100644 --- a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py +++ b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py @@ -23,6 +23,7 @@ def main( for i in T.Parallel(4): D_shared[i, k] = S_fragment[i] T.copy(D_shared, D) + return main def after(): @@ -43,8 +44,9 @@ def main( if i < 4: D_shared[i, k] = S_fragment[i] T.copy(D_shared, D) + return main - + return tvm.IRModule({"main": before()}), tvm.IRModule({"main": after()}) From 8d3226c9b1b62a5aafffd5dcf45baab28ad328c9 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 14:34:59 +0800 Subject: [PATCH 05/16] [Layout] Check inverse layout and loop extent in Parallel Op --- src/op/parallel.cc | 22 +++++++++++++++++++ src/transform/loop_partition.cc | 21 ++---------------- ...elang_transform_loop_partition_boundary.py | 5 +++-- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 822f52622..39205ad21 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -546,6 +546,28 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, } } + // To avoid fragment layout exceed loop extent's boundary + if (loop_layout_.defined()) { + auto inv_layout = loop_layout_->Inverse(); + Array thread_indices; + for (size_t i = 0; i < loop_layout_->OutputDim(); i++) { + thread_indices.push_back(0); + } + 
thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min); + auto logical_indices = inv_layout->Forward(thread_indices); + + // Check if the logical indices exceed the original loop_vars_ extent + for (size_t i = 0; i < loop_vars_.size(); ++i) { + PrimExpr logical_i = logical_indices[i]; + PrimExpr original_extent = loop_vars_[i]->dom->extent; + + // If the logical index cannot be proved to be less than the original extent, add a predicate + if (!analyzer_.CanProve(LT(logical_i, original_extent))) { + AddPredicate(LT(logical_i, original_extent)); + } + } + } + if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) { AddPredicate( LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min)); diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 1c2c21c83..2b46c2bcb 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -94,25 +94,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, has_thread_offset = true; } - bool bounds_match = true; - for (int i = 0; i < old_loop_depth; i++) { - const ForNode *loop = body.as(); - ICHECK(loop != nullptr) - << "No extra statements are allowed between nested parallel loops."; - if (!analyzer->CanProve(loop->extent == loop_layout->InputShape()[i]) || - !analyzer->CanProve(loop->min == 0)) { - bounds_match = false; - } - vmap.Set(loop->loop_var, indices[i]); - loop_mins.push_back(loop->min); - loop_extents.push_back(loop->extent); - body = loop->body; - } - - // Must check the guard if the layout can not be proved as bijective or bounds - // don't match - bool need_guard = - (inverse_info.second != arith::IterMapLevel::Bijective) || !bounds_match; + // Must check the guard if the layout can not be proved as bijective + bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective; // substitute and re-construct the serial loop body = Substitute(body, vmap); diff --git 
a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py index a40b52d04..85d47c87b 100644 --- a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py +++ b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py @@ -33,6 +33,7 @@ def main( D: T.Tensor((4, 64), T.bfloat16), ): with T.Kernel(1, threads=128): + tx = T.get_thread_binding(0) S_shared = T.alloc_shared((8), T.bfloat16) S_fragment = T.alloc_fragment((8), T.float32) D_shared = T.alloc_shared((4, 64), T.bfloat16) @@ -40,8 +41,8 @@ def main( T.copy(S, S_shared) T.copy(S_shared, S_fragment) for k in T.serial(64): - for i in T.Parallel(8): - if i < 4: + if tx < 4: + for i in T.Parallel(8): D_shared[i, k] = S_fragment[i] T.copy(D_shared, D) From 8598d89e1e2506481862dff31bac2ccac5478f2a Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 14:36:58 +0800 Subject: [PATCH 06/16] [Lint] --- src/op/parallel.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 39205ad21..7e6591ba2 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -561,7 +561,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, PrimExpr logical_i = logical_indices[i]; PrimExpr original_extent = loop_vars_[i]->dom->extent; - // If the logical index cannot be proved to be less than the original extent, add a predicate + // If the logical index cannot be proved to be less than the original + // extent, add a predicate if (!analyzer_.CanProve(LT(logical_i, original_extent))) { AddPredicate(LT(logical_i, original_extent)); } From e289b252183de74852c7bdb58dc06ad30b12a380 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 14:43:09 +0800 Subject: [PATCH 07/16] Test Fix --- .../test_tilelang_transform_loop_partition_boundary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py index 85d47c87b..824ebeedf 100644 --- a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py +++ b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py @@ -42,7 +42,7 @@ def main( T.copy(S_shared, S_fragment) for k in T.serial(64): if tx < 4: - for i in T.Parallel(8): + for i in T.Parallel(4): D_shared[i, k] = S_fragment[i] T.copy(D_shared, D) From 87caaaa5820d4eeb550a374939ae938d66944646 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 14:45:57 +0800 Subject: [PATCH 08/16] Revert --- src/transform/loop_partition.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 2b46c2bcb..3365104d1 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -93,7 +93,15 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, thread_offset_map.Set(thread_var, thread_var - range->min); has_thread_offset = true; } - + for (int i = 0; i < old_loop_depth; i++) { + const ForNode *loop = body.as(); + ICHECK(loop != nullptr) + << "No extra statements are allowed between nested parallel loops."; + vmap.Set(loop->loop_var, indices[i]); + loop_mins.push_back(loop->min); + loop_extents.push_back(loop->extent); + body = loop->body; + } // Must check the guard if the layout can not be proved as bijective bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective; From e2e9b07eceef0f7085b0011e9655ac3537aaf1b2 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 18:55:56 +0800 Subject: [PATCH 09/16] [BugFix] Use loop variables in predicate generation --- src/op/parallel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 7e6591ba2..7d4a97e3c 100644 --- 
a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -551,7 +551,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, auto inv_layout = loop_layout_->Inverse(); Array thread_indices; for (size_t i = 0; i < loop_layout_->OutputDim(); i++) { - thread_indices.push_back(0); + thread_indices.push_back(loop_vars_[i]->var); } thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min); auto logical_indices = inv_layout->Forward(thread_indices); From 46b3458898c07969a84afd9eb4cb88f4bcc70fdd Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 19:13:58 +0800 Subject: [PATCH 10/16] Fix --- src/op/parallel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 7d4a97e3c..5038e08e7 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -551,7 +551,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, auto inv_layout = loop_layout_->Inverse(); Array thread_indices; for (size_t i = 0; i < loop_layout_->OutputDim(); i++) { - thread_indices.push_back(loop_vars_[i]->var); + thread_indices.push_back(InputPlaceholder(i + 1)); } thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min); auto logical_indices = inv_layout->Forward(thread_indices); From 5c96c104bd92a32f51861d949d3af2dc7058624f Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 20:51:00 +0800 Subject: [PATCH 11/16] Revert to consider both first address of each thread and fine-grained guard in partition loop --- src/op/parallel.cc | 2 +- src/transform/loop_partition.cc | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 5038e08e7..7e6591ba2 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -551,7 +551,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, auto inv_layout = loop_layout_->Inverse(); Array thread_indices; for (size_t i = 0; i < loop_layout_->OutputDim(); i++) { - 
thread_indices.push_back(InputPlaceholder(i + 1)); + thread_indices.push_back(0); } thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min); auto logical_indices = inv_layout->Forward(thread_indices); diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 3365104d1..9dd050e50 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -93,17 +93,22 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, thread_offset_map.Set(thread_var, thread_var - range->min); has_thread_offset = true; } + bool bounds_match = true; for (int i = 0; i < old_loop_depth; i++) { const ForNode *loop = body.as(); ICHECK(loop != nullptr) << "No extra statements are allowed between nested parallel loops."; + if (!analyzer->CanProve(loop->extent == loop_layout->InputShape()[i]) || + !analyzer->CanProve(loop->min == 0)) { + bounds_match = false; + } vmap.Set(loop->loop_var, indices[i]); loop_mins.push_back(loop->min); loop_extents.push_back(loop->extent); body = loop->body; } // Must check the guard if the layout can not be proved as bijective - bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective; + bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective || !bounds_match; // substitute and re-construct the serial loop body = Substitute(body, vmap); From 9d3286452f576a828629471725af05db252fadb0 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 20:51:13 +0800 Subject: [PATCH 12/16] Lint --- src/transform/loop_partition.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 9dd050e50..a233f9b82 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -108,7 +108,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, body = loop->body; } // Must check the guard if the layout can not be proved as bijective - bool need_guard = 
inverse_info.second != arith::IterMapLevel::Bijective || !bounds_match; + bool need_guard = + inverse_info.second != arith::IterMapLevel::Bijective || !bounds_match; // substitute and re-construct the serial loop body = Substitute(body, vmap); From 13d920b8bc2a8a0f8aea67ab9559b3300499d6c9 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 21:56:05 +0800 Subject: [PATCH 13/16] Add robust inverse layout itermap check --- src/layout/layout.cc | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 8b0d37cb8..3bd44f947 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -271,8 +271,16 @@ std::pair LayoutNode::InverseWithLevel() const { "NoCheck; symbolic dims: " << symbolic_dims; } - arith::IterMapResult res = - arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); + + arith::IterMapResult res; + while (true) { + res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); + if (res->errors.empty()) break; + if (level == arith::IterMapLevel::NoCheck) break; + level = static_cast((static_cast(level) * 2) + 1); + } + LOG(INFO) << "Layout::InverseWithLevel level: " << static_cast(level); + if (!res->errors.empty()) { std::ostringstream msg; msg << "Layout " << DebugOutput() << " has errors: " << res->errors; @@ -607,7 +615,14 @@ arith::IterMapResult FragmentNode::DetectInjective() const { << "NoCheck; symbolic dims: " << symbolic_dims; } - return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer); + arith::IterMapResult res; + while (true) { + res = arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer); + if (res->errors.empty()) break; + if (level == arith::IterMapLevel::NoCheck) break; + level = static_cast((static_cast(level) * 2) + 1); + } + return res; } PrimExpr FragmentNode::ThreadExtent() const { From b13e9fb227649b6fb9851f3dcf4918229140e841 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 27 Jan 2026 
22:06:37 +0800 Subject: [PATCH 14/16] Lint --- src/layout/layout.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 3bd44f947..2ee15b1f0 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -274,9 +274,12 @@ std::pair LayoutNode::InverseWithLevel() const { arith::IterMapResult res; while (true) { - res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); - if (res->errors.empty()) break; - if (level == arith::IterMapLevel::NoCheck) break; + res = + arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); + if (res->errors.empty()) + break; + if (level == arith::IterMapLevel::NoCheck) + break; level = static_cast((static_cast(level) * 2) + 1); } LOG(INFO) << "Layout::InverseWithLevel level: " << static_cast(level); @@ -618,8 +621,10 @@ arith::IterMapResult FragmentNode::DetectInjective() const { arith::IterMapResult res; while (true) { res = arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer); - if (res->errors.empty()) break; - if (level == arith::IterMapLevel::NoCheck) break; + if (res->errors.empty()) + break; + if (level == arith::IterMapLevel::NoCheck) + break; level = static_cast((static_cast(level) * 2) + 1); } return res; From b7bf609e244db7781cbf4af4e50b8994089b6aed Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 28 Jan 2026 00:19:01 +0800 Subject: [PATCH 15/16] Fix layout level result bug --- src/layout/layout.cc | 14 +++++--------- .../test_tilelang_fragment_loop_checker.py | 1 - 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 2ee15b1f0..4fd2a4d21 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -273,16 +273,12 @@ std::pair LayoutNode::InverseWithLevel() const { } arith::IterMapResult res; - while (true) { - res = - arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); - if (res->errors.empty()) - break; - if (level 
== arith::IterMapLevel::NoCheck) - break; - level = static_cast((static_cast(level) * 2) + 1); + res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); + // We do not find the exact level of current layout, because non-bijective will leave loop guard to handle the boundary. + // So just set the level flag and use bijective res to generate correct inverse layour + if (!res->errors.empty()) { + level = arith::IterMapLevel::NoCheck; } - LOG(INFO) << "Layout::InverseWithLevel level: " << static_cast(level); if (!res->errors.empty()) { std::ostringstream msg; diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index 99458f1c8..d13a372df 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -1,7 +1,6 @@ import tilelang import tilelang.testing import tilelang.language as T -import pytest @tilelang.jit From a0789806a9b0a7edd6955726d8db941e8d508c0c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 28 Jan 2026 00:19:34 +0800 Subject: [PATCH 16/16] Lint --- src/layout/layout.cc | 5 +++-- .../python/analysis/test_tilelang_fragment_loop_checker.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 4fd2a4d21..4ac0a11db 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -274,8 +274,9 @@ std::pair LayoutNode::InverseWithLevel() const { arith::IterMapResult res; res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); - // We do not find the exact level of current layout, because non-bijective will leave loop guard to handle the boundary. - // So just set the level flag and use bijective res to generate correct inverse layour + // We do not find the exact level of current layout, because non-bijective + // will leave loop guard to handle the boundary. 
So just set the level flag + // and use bijective res to generate correct inverse layout if (!res->errors.empty()) { level = arith::IterMapLevel::NoCheck; } diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index d13a372df..99458f1c8 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -1,6 +1,7 @@ import tilelang import tilelang.testing import tilelang.language as T +import pytest @tilelang.jit