diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 37d8f67580fe..dde33fa2678d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (reduce_extent <= warp_size_) { - if (group_extent > 1 && reduce_extent < warp_size_) { - mask = mask & - (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index))); - } std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq); @@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{group_index * n_warps + reduce_index}); } - if (n_warps < warp_size_) { - mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps)); - } std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, n_warps, group_index, mask, /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq); diff --git a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py index d8c9568da90e..18d6339349ff 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py @@ -342,10 +342,7 @@ def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")): t0 = T.decl_buffer([1], "float32", scope="local") A_1 = T.Buffer((256,), data=A.data) red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), - T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)), - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32) red_buf0_1[0] = red_buf0_1[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32) @@ -421,7 +418,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_x] - mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15)) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -573,9 +570,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -657,9 +652,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 16: red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)