diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 8682706db899..966f6d31c725 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -162,13 +162,6 @@ struct ReduceOpConversion
 
     auto mod = op->getParentOfType<ModuleOp>();
     unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    if (iWarpSize > numLaneToReduce) {
-      Value threadId = getThreadId(rewriter, loc);
-      Value warpSize = i32_val(iWarpSize);
-      Value laneId = urem(threadId, warpSize);
-      Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
-      pred = pred ? and_(pred, lanePred) : lanePred;
-    }
 
     for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
       SmallVector<Value> shfl(acc.size());
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d2dbe2044236..c0d38843ab24 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -5940,6 +5940,30 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr):
     torch.testing.assert_close(Z, X.sum().to(torch.int32))
 
 
+@pytest.mark.parametrize("reduce_dim", [0, 1])
+def test_side_effectful_reduction_2d(device, reduce_dim):
+    if device != "cuda":
+        pytest.skip()
+
+    @triton.jit(debug=True)
+    def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr,
+                               NON_REDUCE_DIM: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :]
+        vals = tl.load(X + offsets)
+        z = tl.reduce(vals, reduce_dim, sanitize_add)
+        tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z)
+
+    BLOCK_0 = 16
+    BLOCK_1 = 32
+    NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0
+    torch.manual_seed(42)
+    X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32)
+    Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32)
+    sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim,
+                                  NON_REDUCE_DIM=NON_REDUCE_DIM)
+    torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))
+
+
 def test_side_effectful_scan(device):
     if device != "cuda":
         pytest.skip()