Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ unsigned getNumCTAs(Attribute layout);
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order that represents that the dot operand is in kMajor
// Return the order that represents that the dot operand is in kContig
// (contiguous in the inner dimension) or it's contiguous on the outer
// dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);
bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

Expand Down
16 changes: 8 additions & 8 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,15 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
bool kContig) {
// kContig: if true, the matrix is fastest-running on k,
// otherwise it is on m (resp. n)
// opIdx=0: [batch, m, k] if rank == 3 else [m, k]
// opIdx=1: [batch, k, n] if rank == 3 else [k, n]
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
auto rowMajor = bool(opIdx) != kMajor;
auto rowMajor = bool(opIdx) != kContig;
return getMatrixOrder(rank, rowMajor);
}

Expand Down Expand Up @@ -268,7 +268,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kContig*/ true);
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
Expand Down Expand Up @@ -987,7 +987,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
/*kContig*/ true);
}

LogicalResult DotOperandEncodingAttr::verify(
Expand Down Expand Up @@ -1959,7 +1959,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
SmallVector<unsigned>
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
Expand Down Expand Up @@ -2027,7 +2027,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
SmallVector<unsigned>
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
Expand Down Expand Up @@ -2219,7 +2219,7 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
return base;
}

bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx) {
bool isKContig(llvm::ArrayRef<unsigned> order, int opIdx) {
auto rank = order.size();
int kdim = opIdx == 0 ? rank - 1 : rank - 2;
return order[0] == kdim;
Expand Down Expand Up @@ -102,9 +102,9 @@ bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout,
const auto swizzleSlowDimSize =
sharedLayout.getMaxPhase() * sharedLayout.getPerPhase();
const auto swizzlePatternSizeK =
isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
const auto swizzlePatternSizeNonK =
!isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
!isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;

const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1];
const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
const SharedMemoryObject &smemObj,
ArrayRef<Value> strides);

bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx);
bool isKContig(llvm::ArrayRef<unsigned> order, int opIdx);

using computeTensorElemMappingInBlockT =
std::function<llvm::SmallVector<llvm::SmallVector<Value>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Value smemBase;
auto smemStrides = smemObj.getStrides(aTensorTy, loc, rewriter);
bool isFastPath =
!AMD::isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
!AMD::isKContig(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
if (isFastPath) {
// fast path handles tensors that are not k-major and have swizzling
// disabled, in which case offsets computation can be simplified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
// getValuesFromDotOperandLayoutStruct as both a and b are K-major
assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(),
aShapePerCTA.size(),
/*kMajor=*/true));
/*kContig=*/true));
auto ha = getValuesFromDotOperandLayoutStruct(
typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);

assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(),
bShapePerCTA.size(),
/*kMajor=*/true));
/*kContig=*/true));
auto hb = getValuesFromDotOperandLayoutStruct(
typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);

Expand Down