From e9681789d121240473e9c22d34c8169778946a36 Mon Sep 17 00:00:00 2001 From: Qingyi Liu Date: Tue, 15 Aug 2023 15:14:47 -0700 Subject: [PATCH] [BACKEND] Fix nPerWarp == 8 in MMA16816SmemLoader --- .../SharedToDotOperandMMAv2.cpp | 44 ++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 189db47608df..603e980940d5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr; // Data loader for mma.16816 instruction. class MMA16816SmemLoader { public: - MMA16816SmemLoader(int warpsPerTile, ArrayRef order, + MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef order, ArrayRef warpsPerCTA, uint32_t kOrder, int kWidth, ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, @@ -93,6 +93,8 @@ class MMA16816SmemLoader { int inWarpMatOffset; // Offset in number of matrices to increment on non-k dim across warps int warpMatOffset; + + int nPerWarp; }; SmallVector @@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane, // address (s0,s1) annotates. Value matOff[2]; - matOff[kOrder ^ 1] = add( - mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1) - mul(nkMatArr, - i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1) + // When B's shape(k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory + // access will be out of bound. In the future we should change this case to + // ldmatrix.x2 + if (kOrder == 0 && nPerWarp == 8) { + matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset)); + } else { + matOff[kOrder ^ 1] = add( + mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1) + mul(nkMatArr, + i32_val( + inWarpMatOffset))); // matrix offset inside a warp (kOrder=1) + } matOff[kOrder] = kMatArr; // Physical offset (before swizzling) @@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef ptrs, Type matTy, } MMA16816SmemLoader::MMA16816SmemLoader( - int warpsPerTile, ArrayRef order, ArrayRef warpsPerCTA, - uint32_t kOrder, int kWidth, ArrayRef smemStrides, - ArrayRef tileShape, ArrayRef instrShape, - ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, - ConversionPatternRewriter &rewriter, + int nPerWarp, int warpsPerTile, ArrayRef order, + ArrayRef warpsPerCTA, uint32_t kOrder, int kWidth, + ArrayRef smemStrides, ArrayRef tileShape, + ArrayRef instrShape, ArrayRef matShape, int perPhase, + int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc) - : order(order.begin(), order.end()), + : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), instrShape(instrShape.begin(), instrShape.end()), @@ -490,6 +500,7 @@ std::function getLoadMatrixFn( bool isA, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = tensor.getType().cast(); + auto shapePerCTA = getShapePerCTA(tensorTy); Type eltTy = tensorTy.getElementType(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. @@ -511,13 +522,16 @@ std::function getLoadMatrixFn( if (kWidth != (4 / elemBytes)) assert(vecPhase == 1 || vecPhase == 4 * kWidth); + int nPerWarp = + std::max(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8); + // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int a, int b) { MMA16816SmemLoader loader( - warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(), - kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/, - instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, - typeConverter, loc); + nPerWarp, warpsPerTile, sharedLayout.getOrder(), + mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, + tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, + maxPhase, elemBytes, rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs =