diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 4e3c093e8cdf..7973f63a0115 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -298,14 +298,7 @@ struct StoreOpConversion
       vec = std::min(vec, maskAlign);
     }
 
-    // numElements = 1 for scalar
-    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
-    auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
-
+    Value mask = getMask(valueTy, rewriter, loc);
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index 6ea2e0ff0e1a..cdfa731a7fb3 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -421,6 +421,46 @@ class ConvertTritonGPUOpToLLVMPatternBase {
   // -----------------------------------------------------------------------
   // Utilities
   // -----------------------------------------------------------------------
+  Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
+                Location loc) const {
+    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
+    Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    if (tensorTy) {
+      auto layout = tensorTy.getEncoding();
+      auto shape = tensorTy.getShape();
+      unsigned rank = shape.size();
+      auto sizePerThread = triton::gpu::getSizePerThread(layout);
+      auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
+      auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
+      auto order = triton::gpu::getOrder(layout);
+      auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
+      Value warpSize = i32_val(32);
+      Value laneId = urem(tid, warpSize);
+      Value warpId = udiv(tid, warpSize);
+      SmallVector<Value> multiDimWarpId =
+          delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+      SmallVector<Value> multiDimThreadId =
+          delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      for (unsigned dim = 0; dim < rank; ++dim) {
+        // if there is no data replication across threads on this dimension
+        if (shape[dim] >= shapePerCTA[dim])
+          continue;
+        // Otherwise, we need to mask threads that will replicate data on this
+        // dimension. Calculate the thread index on this dimension for the CTA
+        Value threadDim =
+            add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
+                multiDimThreadId[dim]);
+        mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
+                                   i32_val(shape[dim])));
+      }
+    } else {
+      // If the tensor is not ranked, then it is a scalar and only thread 0 can
+      // write
+      mask = and_(mask, icmp_slt(tid, i32_val(1)));
+    }
+    return mask;
+  }
 
   // Convert an \param index to a multi-dim coordinate given \param shape and
   // \param order.
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 0fd00dae525c..8cf76c713b29 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1488,6 +1488,53 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
     np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
 
 
+layouts = [
+    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
+    MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
+]
+
+
+@pytest.mark.parametrize("M", [32, 64, 128, 256])
+@pytest.mark.parametrize("src_layout", layouts)
+def test_store_op(M, src_layout, device='cuda'):
+    ir = f"""
+    #src = {src_layout}
+    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
+        tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
+            %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+            %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+            %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+            %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+            %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
+            %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+            %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
+            %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
+            %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
+            tt.store %8, %4 : tensor<{M}x1xf32, #src>
+            tt.return
+        }}
+    }}
+    """
+
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(ir)
+        f.flush()
+        store_kernel = triton.compile(f.name)
+
+    rs = RandomState(17)
+    x = rs.randint(0, 4, (M, 1)).astype('float32')
+    y = np.zeros((M, 1), dtype='float32')
+    x_tri = torch.tensor(x, device=device)
+    y_tri = torch.tensor(y, device=device)
+
+    pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
+    y_ref = x
+
+    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
+
+
 @triton.jit
 def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
     delta = mean_2 - mean_1
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index d05ee816fa97..342310eceed5 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1038,7 +1038,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: store_f32
   tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
-    // CHECK: llvm.icmp "slt"
     // CHECK: llvm.inline_asm
     // CHECK-SAME: @$2 st.global.b32
     // CHECK: llvm.inline_asm