Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
Expand Down Expand Up @@ -93,6 +93,8 @@ class MMA16816SmemLoader {
int inWarpMatOffset;
// Offset in number of matrices to increment on non-k dim across warps
int warpMatOffset;

int nPerWarp;
};

SmallVector<Value>
Expand Down Expand Up @@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
// address (s0,s1) annotates.

Value matOff[2];
matOff[kOrder ^ 1] = add(
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
// When B's shape (k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory
// access will be out of bounds. In the future we should change this case to
// use ldmatrix.x2.
if (kOrder == 0 && nPerWarp == 8) {
matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset));
} else {
matOff[kOrder ^ 1] = add(
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(
inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
}
matOff[kOrder] = kMatArr;

// Physical offset (before swizzling)
Expand Down Expand Up @@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
}

MMA16816SmemLoader::MMA16816SmemLoader(
int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder, int kWidth,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()),
: nPerWarp(nPerWarp), order(order.begin(), order.end()),
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
Expand Down Expand Up @@ -490,6 +500,7 @@ std::function<void(int, int)> getLoadMatrixFn(
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shapePerCTA = getShapePerCTA(tensorTy);
Type eltTy = tensorTy.getElementType();
// We assume that the input operand of Dot should be from shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
Expand All @@ -511,13 +522,16 @@ std::function<void(int, int)> getLoadMatrixFn(
if (kWidth != (4 / elemBytes))
assert(vecPhase == 1 || vecPhase == 4 * kWidth);

int nPerWarp =
std::max<int>(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8);

// (a, b) is the coordinate.
auto load = [=, &rewriter, &vals](int a, int b) {
MMA16816SmemLoader loader(
warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
typeConverter, loc);
nPerWarp, warpsPerTile, sharedLayout.getOrder(),
mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides,
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
// Offset of a slice within the original tensor in shared memory
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offs =
Expand Down