Commit 5345b05

[Fix][TIR] LowerThreadAllreduce warp reduction mask
The warp reduction implemented with the "shuffle down" primitive takes a mask denoting the active threads within the warp that participate in the shuffle. Previously we narrowed this mask to the lanes of the current reduction group. In practice, we found that doing so results in a "CUDA illegal instruction" error on the NVIDIA H100 GPU, and the issue is gone if we do not update the mask. Therefore, this PR updates the allreduce lowering to remove the mask update.

Confirmed the correctness on the following devices:
* NVIDIA H100,
* NVIDIA RTX 4090,
* AMD Radeon 7900 XTX,
* Apple M2 Ultra.
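For readers unfamiliar with the "shuffle down" reduction pattern, the following is a minimal host-side sketch of it. It is a CPU model written for illustration only, not TVM's `MakeWarpAllreduce` and not device code: `kWarpSize`, `ShuffleDown`, and `WarpReduceSum` are hypothetical names, a 32-lane warp is assumed, and every lane is treated as active (the mask itself is not modeled).

```cpp
#include <array>

constexpr int kWarpSize = 32;  // assumed warp width, for illustration only

// Model of one shuffle-down step: lane i reads the value held by lane
// i + offset when that lane lies in the same `width`-lane segment;
// otherwise it keeps its own value.
std::array<int, kWarpSize> ShuffleDown(const std::array<int, kWarpSize>& v,
                                       int offset, int width) {
  std::array<int, kWarpSize> out{};
  for (int lane = 0; lane < kWarpSize; ++lane) {
    int src = lane + offset;
    bool same_segment = src < kWarpSize && (src / width) == (lane / width);
    out[lane] = same_segment ? v[src] : v[lane];
  }
  return out;
}

// Tree reduction over segments of `width` lanes (width must be a power of
// two). After log2(width) steps, the first lane of each segment holds the
// sum of that segment; the other lanes hold partial values.
std::array<int, kWarpSize> WarpReduceSum(std::array<int, kWarpSize> v,
                                         int width) {
  for (int offset = width / 2; offset > 0; offset /= 2) {
    std::array<int, kWarpSize> shifted = ShuffleDown(v, offset, width);
    for (int lane = 0; lane < kWarpSize; ++lane) v[lane] += shifted[lane];
  }
  return v;
}
```

With `width` smaller than the warp size, several independent groups are reduced inside one warp. That is the case in which the lowering used to narrow the mask to a single group's lanes.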
1 parent 99defd2 commit 5345b05

1 file changed: +0 -7 lines changed

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 0 additions & 7 deletions
@@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});

       if (reduce_extent <= warp_size_) {
-        if (group_extent > 1 && reduce_extent < warp_size_) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
-        }
         std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);

@@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
                                  /*indices=*/{group_index * n_warps + reduce_index});
         }
-        if (n_warps < warp_size_) {
-          mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
-        }
         std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, n_warps, group_index, mask,
             /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
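To make the removed expression concrete, here is a small host-side sketch of the mask arithmetic from the first hunk. `OldSubWarpMask` is a hypothetical helper, not a TVM function. It assumes a 32-bit mask and `reduce_extent < 32`, matching the removed `reduce_extent < warp_size_` guard on a 32-lane warp, and it uses unsigned literals where the original wrote `1 << reduce_extent`.

```cpp
#include <cstdint>

// Hypothetical helper mirroring the removed expression:
//   mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index))
// It keeps only the lanes that belong to reduction group `group_index`.
// Assumes reduce_extent < 32 and reduce_extent * group_index < 32 so that
// both shifts are well defined.
uint32_t OldSubWarpMask(uint32_t active_mask, int reduce_extent,
                        int group_index) {
  uint32_t group_bits = ((1u << reduce_extent) - 1u)
                        << (reduce_extent * group_index);
  return active_mask & group_bits;
}
```

For example, with all 32 lanes active, `reduce_extent = 8`, and `group_index = 2`, the helper returns `0x00FF0000`, that is, lanes 16 through 23. After this commit the lowering passes the unmodified result of `tvm_warp_activemask()` instead.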
