Loop unswitching: follow let bindings recursively.

In UsesLoopVarThroughLetBindings, the expression bound to a let variable is
now checked by calling UsesLoopVarThroughLetBindings again (passing
let_bindings along) instead of a direct UsesVar test against loop_var.

Adds test_no_hoist_multiple_let, which compiles a kernel whose loop body
branches on a value loaded through the loop index. The test only calls
compile(); it makes no assertion about the resulting IR.

diff --git a/src/transform/loop_unswitching.cc b/src/transform/loop_unswitching.cc
index aab8f7fdc..7767866bb 100644
--- a/src/transform/loop_unswitching.cc
+++ b/src/transform/loop_unswitching.cc
@@ -201,8 +201,8 @@ bool UsesLoopVarThroughLetBindings(
         auto it = let_bindings->find(var_node);
         if (it != let_bindings->end()) {
           // Check if the bound expression uses the loop variable
-          if (UsesVar(it->second,
-                      [&](const VarNode *v) { return v == loop_var.get(); })) {
+          if (UsesLoopVarThroughLetBindings(it->second, loop_var,
+                                            let_bindings)) {
             uses_loop_var = true;
           }
         }
diff --git a/testing/python/transform/test_tilelang_transform_loop_unswitching.py b/testing/python/transform/test_tilelang_transform_loop_unswitching.py
index 91ff355b9..212b57732 100644
--- a/testing/python/transform/test_tilelang_transform_loop_unswitching.py
+++ b/testing/python/transform/test_tilelang_transform_loop_unswitching.py
@@ -386,5 +386,19 @@ def expected(
     _check(before, expected)
 
 
+def test_no_hoist_multiple_let():
+    @tilelang.jit()
+    def get_fused_mapping_kernel(topk_idx: T.Tensor[(1,), T.int32]):
+        with T.Kernel():
+            _tmp1 = T.alloc_shared((1,), "int")
+            for i in T.serial(0, 4, 2):
+                _tmp2 = topk_idx[i]
+                T.assume(0 <= _tmp2 < 1)
+                if _tmp2 != -1:
+                    T.atomic_add(_tmp1[_tmp2], 1)
+
+    get_fused_mapping_kernel.compile()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()