Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ namespace gpu {
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
61 changes: 61 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm if the problem is still there.
inline bool requiresI32Conversion(Type type) {
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return false;
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!dotOpEnc)
return false;
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
if (!(parent && parent.getVersionMajor() < 3))
return false;
return true;
}

inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
auto vecTy = vec_ty(eltTy, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecTy);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
for (auto v : inValues) {
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecTy);
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
Expand Down
22 changes: 4 additions & 18 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}

Expand All @@ -749,20 +749,6 @@ bool atomicNeedsSharedMemory(Value value) {
return true;
}

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
!srcTy.getElementType().isF32();
}

namespace {

/// A data structure similar to SetVector but maintains
Expand Down
43 changes: 22 additions & 21 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
} else {
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
auto dstCvt = requiresI32Conversion(dstTy);
auto srcCvt = requiresI32Conversion(srcTy);
if (dstCvt || srcCvt) {
auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
getTypeConverter());
inVals =
packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
rewriter, op.getType());
rewriter.replaceOp(op, res);
} else {
rewriter.replaceOp(op, adaptor.getSrc());
}
return success();
}
}
Expand All @@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
StringAttr kRegister = str_attr("register");
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
SmallVector<Value> outVals(numRegs);
for (int i = 0; i < outVals.size(); i++) {
for (int i = 0; i < numRegs; i++) {
// Remove free masks from the register index
// For example, if idx = 0b00111, and masks = 0b00100, then we get
// 0b00011. It means that register 7 (0b111) has the same value as
Expand All @@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: idx;
outVals[i] = inVals[srcIdx];
}
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (auto nvidiaMma =
dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
return false;
}
Comment thread
lezcano marked this conversation as resolved.
if (useLegacyMMAConversion) {
return false;
}
Expand All @@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
return largeKWidth && nvidiaMma.isAmpere();
}
return false;
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;
Expand Down Expand Up @@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
}
}
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());

// Pretty sure this is the identity function ATM
// It'd be better to simply call `quotient({kBlock})` and
Expand All @@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

// FIXME [Dot LL]
// We know it's just for largeKWidth case in Ampere
// In this case, we need to pack the outputs into i32
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}

outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down
56 changes: 6 additions & 50 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,51 +103,6 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
llvm_unreachable("unimplemented code path");
}

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter) {
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
return inValues;
SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
auto eltType = typeConverter->convertType(tensorTy.getElementType());
auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecType);
for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter) {
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
return inValues;
SmallVector<Value> outValues;
auto eltType = typeConverter->convertType(tensorTy.getElementType());
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
auto vecType = vec_ty(eltType, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecType);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

int getNumElementsPerThreads(Type type,
const LLVMTypeConverter *typeConverter) {
int numElemsPerThread = 1;
Expand Down Expand Up @@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
unpackedOperands.push_back(
unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter()));
unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
}

int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
Expand Down Expand Up @@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion
unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
/*ouType=*/op->getResult(i).getType());
}
auto packed = packI32(unpackedResults[i], op->getResult(i).getType(),
rewriter, loc, getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter,
op->getResult(i).getType()));
auto dstTy = op->getResult(i).getType();
unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
rewriter, op->getResult(i).getType()));
}

rewriter.replaceOp(op, outs);
Expand Down
37 changes: 1 addition & 36 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,42 +184,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
SmallVector<Value> outVals = loadSharedToDistributed(
dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);

// FIXME [Dot LL]
// Ampere case
// In this case, we need to pack the outputs into i32
if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
if (parent.isAmpere()) {
if (elemLlvmTy.isInteger(8)) {
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
return or_(
or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
or_(shl(zext(i32_ty, a3), i32_val(16)),
shl(zext(i32_ty, a4), i32_val(24))));
};
SmallVector<Value> outVals32(outVals.size() / 4);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
outVals[4 * i + 2], outVals[4 * i + 3]);
}
outVals = outVals32;
} else {
assert(elemLlvmTy.isBF16() && "Unexpected element type");
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}
}
}
}

outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

Expand Down
Loading