diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index 8b0d37cb8..4ac0a11db 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -271,8 +271,16 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
            "NoCheck; symbolic dims: " << symbolic_dims;
   }
-  arith::IterMapResult res =
-      arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
+
+  arith::IterMapResult res;
+  res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
+  // We do not pin down the exact level of the current layout here: for a
+  // non-bijective layout the loop guard handles the boundary, so just set
+  // the level flag and use the result to generate a correct inverse layout.
+  if (!res->errors.empty()) {
+    level = arith::IterMapLevel::NoCheck;
+  }
+
   if (!res->errors.empty()) {
     std::ostringstream msg;
     msg << "Layout " << DebugOutput() << " has errors: " << res->errors;
@@ -607,7 +615,16 @@ arith::IterMapResult FragmentNode::DetectInjective() const {
         << "NoCheck; symbolic dims: " << symbolic_dims;
   }
-  return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
+  arith::IterMapResult res;
+  while (true) {
+    res = arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
+    if (res->errors.empty())
+      break;
+    if (level == arith::IterMapLevel::NoCheck)
+      break;
+    level = static_cast<arith::IterMapLevel>((static_cast<int>(level) * 2) + 1);
+  }
+  return res;
 }
 
 PrimExpr FragmentNode::ThreadExtent() const {
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 822f52622..7e6591ba2 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -546,6 +546,29 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     }
   }
 
+  // Keep the fragment layout from exceeding the loop extents' boundary.
+  if (loop_layout_.defined()) {
+    auto inv_layout = loop_layout_->Inverse();
+    Array<PrimExpr> thread_indices;
+    for (size_t i = 0; i < loop_layout_->OutputDim(); i++) {
+      thread_indices.push_back(0);
+    }
+    thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min);
+    auto logical_indices = inv_layout->Forward(thread_indices);
+
+    // Check whether the logical indices exceed the original loop_vars_ extents.
+    for (size_t i = 0; i < loop_vars_.size(); ++i) {
+      PrimExpr logical_i = logical_indices[i];
+      PrimExpr original_extent = loop_vars_[i]->dom->extent;
+
+      // If the logical index cannot be proved to be less than the original
+      // extent, add a predicate.
+      if (!analyzer_.CanProve(LT(logical_i, original_extent))) {
+        AddPredicate(LT(logical_i, original_extent));
+      }
+    }
+  }
+
   if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
     AddPredicate(
         LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
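For intuition, here is a minimal plain-Python sketch (no TVM; `inverse_forward` and `needed_guards` are hypothetical stand-ins, not tilelang API) of the boundary check that `ParallelOpNode::InferLayout` now performs. It models the new test case below: 4 logical iterations spread across a 128-thread block, so the inverse mapping is in bounds only for `tx < 4`, which is exactly the guard the pass must emit. The C++ code proves bounds symbolically with `arith::Analyzer`; the sketch simply probes every thread id.

def inverse_forward(tx):
    # Hypothetical inverse layout for the test case below: thread tx holds
    # logical row index tx (loop extent 4, 128 threads per block).
    return (tx,)


def needed_guards(extents, num_threads):
    """Collect `logical_i < extent_i` guards for inverse-mapped indices
    that may fall out of bounds."""
    guards = []
    for i, extent in enumerate(extents):
        # "Cannot prove in-bounds" becomes: some thread id maps outside.
        if any(inverse_forward(tx)[i] >= extent for tx in range(num_threads)):
            guards.append(f"logical_{i} < {extent}")
    return guards


print(needed_guards(extents=[4], num_threads=128))  # -> ['logical_0 < 4']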
diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc
index 35166e4b4..a233f9b82 100644
--- a/src/transform/loop_partition.cc
+++ b/src/transform/loop_partition.cc
@@ -83,9 +83,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   Array<PrimExpr> loop_extents;
   auto inverse_info = loop_layout->InverseWithLevel();
   auto inv_loop = inverse_info.first;
-  // Must check the guard if the layout can not be proved as bijective
-  bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
   auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
+  // Normalize thread var once so we can reuse the same substitution later.
   Map<Var, PrimExpr> thread_offset_map;
   bool has_thread_offset = false;
@@ -94,15 +93,24 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
     thread_offset_map.Set(thread_var, thread_var - range->min);
     has_thread_offset = true;
   }
+  bool bounds_match = true;
   for (int i = 0; i < old_loop_depth; i++) {
     const ForNode *loop = body.as<ForNode>();
     ICHECK(loop != nullptr)
         << "No extra statements are allowed between nested parallel loops.";
+    if (!analyzer->CanProve(loop->extent == loop_layout->InputShape()[i]) ||
+        !analyzer->CanProve(loop->min == 0)) {
+      bounds_match = false;
+    }
     vmap.Set(loop->loop_var, indices[i]);
     loop_mins.push_back(loop->min);
     loop_extents.push_back(loop->extent);
     body = loop->body;
   }
+  // Must emit the guard when the layout cannot be proved bijective or the
+  // loop bounds do not match the layout's input shape.
+  bool need_guard =
+      inverse_info.second != arith::IterMapLevel::Bijective || !bounds_match;
+
   // substitute and re-construct the serial loop
   body = Substitute(body, vmap);
   // Guard executes the recovered loop body only if each inverse-mapped iterator
diff --git a/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py
new file mode 100644
index 000000000..824ebeedf
--- /dev/null
+++ b/testing/python/transform/test_tilelang_transform_loop_partition_boundary.py
@@ -0,0 +1,76 @@
+import tilelang
+import tilelang.language as T
+import tilelang.testing
+from tilelang import tvm as tvm
+from tilelang.utils.target import determine_target
+
+
+def _tilelang_transform_loop_partition_boundary():
+
+    def before():
+
+        @T.prim_func
+        def main(
+                S: T.Tensor((8,), T.bfloat16),
+                D: T.Tensor((4, 64), T.bfloat16),
+        ):
+            with T.Kernel(1, threads=128):
+                S_shared = T.alloc_shared((8,), T.bfloat16)
+                S_fragment = T.alloc_fragment((8,), T.float32)
+                D_shared = T.alloc_shared((4, 64), T.bfloat16)
+
+                T.copy(S, S_shared)
+                T.copy(S_shared, S_fragment)
+                for k in T.serial(64):
+                    for i in T.Parallel(4):
+                        D_shared[i, k] = S_fragment[i]
+                T.copy(D_shared, D)
+
+        return main
+
+    def after():
+
+        @T.prim_func
+        def main(
+                S: T.Tensor((8,), T.bfloat16),
+                D: T.Tensor((4, 64), T.bfloat16),
+        ):
+            with T.Kernel(1, threads=128):
+                tx = T.get_thread_binding(0)
+                S_shared = T.alloc_shared((8,), T.bfloat16)
+                S_fragment = T.alloc_fragment((8,), T.float32)
+                D_shared = T.alloc_shared((4, 64), T.bfloat16)
+
+                T.copy(S, S_shared)
+                T.copy(S_shared, S_fragment)
+                for k in T.serial(64):
+                    if tx < 4:
+                        for i in T.Parallel(4):
+                            D_shared[i, k] = S_fragment[i]
+                T.copy(D_shared, D)
+
+        return main
+
+    return tvm.IRModule({"main": before()}), tvm.IRModule({"main": after()})
+
+
+def boundary_check():
+    before, after = _tilelang_transform_loop_partition_boundary()
+    target = tvm.target.Target(determine_target("auto"))
+    with target:
+        with tvm.transform.PassContext():
+            mod = tvm.tir.transform.BindTarget(target)(before)
+            mod = tilelang.transform.LayoutInference()(mod)
+            mod = tilelang.transform.LowerTileOp()(mod)
+            mod = tvm.tir.transform.Simplify()(mod)
+        with tvm.transform.PassContext():
+            ref_mod = tvm.tir.transform.BindTarget(target)(after)
+            ref_mod = tilelang.transform.LayoutInference()(ref_mod)
+            ref_mod = tilelang.transform.LowerTileOp()(ref_mod)
+            ref_mod = tvm.tir.transform.Simplify()(ref_mod)
+    assert mod["main"].script() == ref_mod["main"].script(), \
+        "mod and ref_mod are not structurally equal"
+
+
+def test_tilelang_transform_loop_partition_boundary():
+    boundary_check()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
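For reference, a small self-contained sketch of the two scalar decisions the C++ changes above encode: the one-step relaxation in `FragmentNode::DetectInjective` (`level * 2 + 1`, matching TVM's `IterMapLevel` encoding of Bijective = 0, Surjective = 1, NoCheck = 3) and the `need_guard` rule in `PartitionLoop`. The Python names are illustrative only, not tilelang API.

from enum import IntEnum


class IterMapLevel(IntEnum):
    # Mirrors TVM's encoding; relaxing one step is `level * 2 + 1`.
    BIJECTIVE = 0
    SURJECTIVE = 1
    NO_CHECK = 3


def relax(level: IterMapLevel) -> IterMapLevel:
    # One relaxation step: Bijective -> Surjective -> NoCheck.
    return IterMapLevel(level * 2 + 1)


def need_guard(inverse_level: IterMapLevel, bounds_match: bool) -> bool:
    # A loop guard is required unless the inverse is provably bijective
    # and every nested loop starts at 0 with the layout's input extent.
    return inverse_level != IterMapLevel.BIJECTIVE or not bounds_match


assert relax(IterMapLevel.BIJECTIVE) is IterMapLevel.SURJECTIVE
assert relax(IterMapLevel.SURJECTIVE) is IterMapLevel.NO_CHECK
assert need_guard(IterMapLevel.SURJECTIVE, bounds_match=True)
assert not need_guard(IterMapLevel.BIJECTIVE, bounds_match=True)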
diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py
index 16c7cb802..7245d2b6f 100644
--- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py
+++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py
@@ -72,12 +72,14 @@ def main(
 
         return tvm.IRModule({"main": main})
 
+    before_mod, after_mod = before(), after()
     with tvm.transform.PassContext():
-        mod = tvm.tir.transform.BindTarget(auto_target)(before())
+        mod = tvm.tir.transform.BindTarget(auto_target)(before_mod)
         mod = tl.transform.LowerTileOp()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
-        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
-        ref_mod = tvm.tir.transform.Simplify()(ref_mod)
+    with tvm.transform.PassContext():
+        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after_mod)
+        ref_mod = tvm.tir.transform.Simplify()(ref_mod)
     # Note(tzj): The structures are equal except the argument in "T.reads" function.
     # The difference is just between the first index and the indices range, which is totally equivalent
     tvm.ir.structural_equal(mod, ref_mod)
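The refactor above builds `before_mod`/`after_mod` once, outside any pass context, and gives the reference pipeline its own `tvm.transform.PassContext` so the two runs cannot share configuration. A minimal sketch of that pattern (`run_pipeline` is a hypothetical helper, using only stock TVM passes):

import tvm


def run_pipeline(mod, target, passes=()):
    # A fresh PassContext per pipeline keeps the runs isolated, mirroring
    # the two separate `with` blocks in the test above.
    with tvm.transform.PassContext():
        mod = tvm.tir.transform.BindTarget(target)(mod)
        for apply_pass in passes:
            mod = apply_pass(mod)
        mod = tvm.tir.transform.Simplify()(mod)
    return mod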