From 6a0444e9a8804456487982c4c6970077d7b8a034 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 24 Aug 2025 22:54:12 +0800 Subject: [PATCH 1/2] Update test parameters and remove debug print statement - Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution. - Removed a debug print statement from `phase.py` to clean up the code and enhance clarity. --- .../dynamic/test_tilelang_dynamic_symbolic_bench.py | 8 ++++---- tilelang/engine/phase.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py index 2f534de4b..d67f055d9 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py @@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): def test_all(): - run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32) - run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32) - run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32) - run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32) + run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32) + run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32) + run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32) + run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32) if __name__ == "__main__": diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 91712a664..74874ae11 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.MergeSharedMemoryAllocations( enable_aggressive_merge=enable_aggressive_merge)( mod) - print("mod \n", mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) # Inject PTX async copy must behind the thread sync pass From 12d0094c2aafd74c43f7a87117295d22f3052711 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 24 Aug 2025 23:03:20 +0800 Subject: [PATCH 2/2] Refactor loop stack management in warp_specialized_rewriter - Introduced a new `LoopInfo` struct to encapsulate loop variable details, including `loop_var`, `extent`, and `min`, enhancing clarity and maintainability. - Updated the `loop_stack_` to utilize `LoopInfo` instead of a pair, improving type safety and readability. - Adjusted linear index calculations to account for the new structure, ensuring correct behavior in loop transformations. --- src/transform/warp_specialized_rewriter.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 39cc17ea8..f440e946b 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -24,6 +24,12 @@ using namespace tir; using namespace runtime; using arith::IRVisitorWithAnalyzer; +struct LoopInfo { + Var loop_var; + PrimExpr extent; + PrimExpr min; +}; + enum class Role { kConsumer, kProducer, kBoth }; class ProducerBufferDetector : public StmtExprVisitor { @@ -838,7 +844,7 @@ class WSCodeEmitter : public StmtMutator { num_stages = static_cast(num_stages_anno->as()->value); ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; } - loop_stack_.emplace_back(op->loop_var, op->extent); + loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min}); Array> group_info_array; Array order_info_array; @@ -871,15 +877,14 @@ class WSCodeEmitter : public StmtMutator { num_stages_ = num_stages; pipeline_info_ = pipeline_info; - PrimExpr linear_index = loop_stack_[0].first; + PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min; for (size_t i = 1; i < loop_stack_.size(); ++i) { - linear_index = - linear_index * loop_stack_[i].second + loop_stack_[i].first; + linear_index = linear_index * loop_stack_[i].extent + + (loop_stack_[i].loop_var - loop_stack_[i].min); } stage_ = FloorMod(linear_index, num_stages); parity_ = FloorMod( parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); - auto result = FilterByRole(op); Stmt grouped_for_node; @@ -1137,7 +1142,7 @@ class WSCodeEmitter : public StmtMutator { PrimExpr parity_ = 0; PrimExpr stage_ = 0; int num_stages_ = 1; - std::vector> loop_stack_; + std::vector loop_stack_; Var thread_var_; bool mbarrier_only_ = false; PipelineInfo pipeline_info_;