diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 6370eba55b32..521ffec3a739 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -21,6 +21,7 @@ class AllocationAnalysis; SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); } // namespace triton diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 567501d3dbce..ec3757208016 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getUniqueContigPerThread; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) { return {inOrd, outOrd}; } -SmallVector -getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, - unsigned &outVec) { +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); Attribute srcLayout = srcTy.getEncoding(); @@ -76,15 +75,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, } } - assert(srcLayout && dstLayout && - "Unexpected layout in getScratchConfigForCvtLayout()"); - auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); - unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]]; - unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]]; - // TODO: Fix the legacy issue that ourOrd[0] == 0 always means - // that we cannot do vectorization. - inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 
1 : srcContigPerThread; - outVec = outOrd[0] == 0 ? 1 : dstContigPerThread; + assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvtLayout()"); auto srcShapePerCTA = getShapePerCTA(srcTy); auto dstShapePerCTA = getShapePerCTA(dstTy); @@ -92,21 +83,44 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); unsigned rank = dstTy.getRank(); - SmallVector paddedRepShape(rank); - unsigned pad = std::max(inVec, outVec); + SmallVector repShape(rank); for (unsigned d = 0; d < rank; ++d) { - paddedRepShape[d] = + repShape[d] = std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); } - if (rank == 1) - return paddedRepShape; + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + + auto srcTy = op.getSrc().getType().cast(); + auto dstTy = op.getResult().getType().cast(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that outOrd[0] == 0 always means + // that we cannot do vectorization. + inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; + outVec = outOrd[0] == 0 ? 
1 : dstContigPerThread; + + if (repShape.size() <= 1) + return repShape; unsigned paddedDim = 1; if (auto dstBlockedLayout = dstLayout.dyn_cast()) { paddedDim = dstBlockedLayout.getOrder()[0]; } - paddedRepShape[paddedDim] += pad; - return paddedRepShape; + unsigned pad = std::max(inVec, outVec); + repShape[paddedDim] += pad; + return repShape; } SmallVector diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1aadf5093884..70e675c7bc5f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -237,12 +237,30 @@ struct ConvertLayoutOpConversion llvm_unreachable("unexpected layout in getMultiDimOffset"); } + SmallVector + getWrappedMultiDimOffset(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, + ArrayRef shape, + SmallVector shapePerCTATile, + SmallVector shapePerCTA) const { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; + } + // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, ArrayRef numCTAsEachRep, ArrayRef multiDimRepId, unsigned vec, ArrayRef paddedRepShape, + ArrayRef origRepShape, ArrayRef outOrd, SmallVector &vals, Value smemBase) const { auto accumNumCTAsEachRep = product(numCTAsEachRep); @@ -286,8 +304,11 @@ struct ConvertLayoutOpConversion SmallVector multiDimOffset = getMultiDimOffset(layout, loc, rewriter, elemId, type, multiDimCTAInRepId, shapePerCTATile); - Value offset = - linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd); + SmallVector 
multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); @@ -575,6 +596,7 @@ struct ConvertLayoutOpConversion rewriter, srcTy); unsigned inVec = 0; unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); if (getElementTypeOrSelf(op.getType()) .isa()) { @@ -618,7 +640,7 @@ struct ConvertLayoutOpConversion else processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, - outOrd, vals, smemBase); + origRepShape, outOrd, vals, smemBase); } else { assert(0 && "ConvertLayout with input layout not implemented"); return failure(); @@ -651,7 +673,8 @@ struct ConvertLayoutOpConversion else processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, - paddedRepShape, outOrd, outVals, smemBase); + paddedRepShape, origRepShape, outOrd, outVals, + smemBase); } else { assert(0 && "ConvertLayout with output layout not implemented"); return failure(); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 2b12e727025f..78c4f92caa53 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -339,6 +339,8 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(outVec * (maxPhase - 1) <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, 
srcEncoding, srcTy, false); // Swizzling with leading offsets (e.g. Hopper GMMA) @@ -452,10 +454,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcSharedLayout); auto outOrd = triton::gpu::getOrder(dstDistributedLayout); - unsigned outVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]] - : 1; + unsigned outVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; unsigned inVec = srcSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); @@ -501,10 +503,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcDistributedLayout); auto outOrd = dstSharedLayout.getOrder(); - unsigned inVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(srcDistributedLayout)[inOrd[0]] - : 1; + unsigned inVec = inOrd == outOrd + ? 
triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; unsigned outVec = dstSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 70f8d3d80900..df798d5ac6c2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3607,6 +3607,7 @@ def kernel(Out): # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), # MmaLayout(1, [4, 1], [1, 1], [0, 1]), # MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -3624,15 +3625,16 @@ def kernel(Out): ] -@pytest.mark.parametrize("shape", [(128, 128)]) +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): if is_hip(): pytest.skip("test_convert2d is not supported in HIP") - + if (M == 1 or N == 1) and interm_layout: + pytest.skip("Out of bound access when maxPhase > 1") if str(src_layout) == str(dst_layout): pytest.skip() if 'mma' in str(src_layout) and 'mma' in str(dst_layout): @@ -3648,43 +3650,43 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): """ conversion = f""" - %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - 
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm> - %16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src> - %17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm> - %18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src> + %15 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #interm> + %16 = triton_gpu.convert_layout %15 : (tensor<{M}x{N}xi32, #interm>) -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #interm> + %18 = triton_gpu.convert_layout %17 : (tensor<{M}x{N}xf16, #interm>) -> tensor<{M}x{N}xf16, #src> - %12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - %13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %16 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ - ir = layouts + """ - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<128> : tensor<128x1xi32, #src> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> - %1 = tt.make_range {end = 
128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> - %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #src> - %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src> - %5 = arith.muli %4, %cst : tensor<128x1xi32, #src> - %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src> - %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src> - %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src> - %9 = arith.addi %8, %7 : tensor<128x128xi32, #src> - %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr, #src>, tensor<128x128xi32, #src> - %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src> - %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #dst> - """ + conversion + """ - %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr, #dst>, tensor<128x128xi32, #dst> - tt.store %14, %13 : tensor<128x128xf16, #dst> + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : 
(tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}xf16, #dst> tt.return - } -} + }} +}} """ - x = to_triton(numpy_random(shape, dtype_str=dtype), device=device) + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) # write the IR to a temporary file using mkstemp