Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
// Register layout conversion:
//
// [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
// [2, 3], [6, 7] [2], [3], [6], [7]
//
// Original access order:
//
// [0, 1], [2, 3], [4, 5], [6, 7]
//
// Transformed access order:
//
// [0], [2], [1], [3], [4], [6], [5], [7]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
// Register layout conversion:
//
// [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15]
//
// Original access order:
//
// [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
//
// Transformed access order:
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
SmallVector<Value> vecVals;
SmallVector<Type> types;
// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of ldmatrix in i32
Expand All @@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
val = or_(i32_ty, val, ext);
}
vecVals.push_back(val);
vecVals.push_back(bitcast(val, i32_ty));
}
elems = elems / (32 / elemSize);
types = SmallVector<Type>(elems, i32_ty);
} else {
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
Type vecTy = vec_ty(elemTy, vecSize);
types = SmallVector<Type>(elems / vecSize, vecTy);
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
vecVals.push_back(bitcast(packed, i32_ty));
}
}

// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactored so we don't
// implicitly depend on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}

Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
rewriter, dstTy);
Value view =
packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct(
for (int m = 0; m < n0; ++m)
for (int k = 0; k < n1; ++k) {
elems.push_back(vals.at({b, 2 * m, 2 * k}));
elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
elems.push_back(vals.at({b, 2 * m + 1, 2 * k}));
elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
}
assert(!elems.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(

// For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
if (dot.getOpIdx() == 0) {
si = llvm::SmallVector<unsigned>{0, 8, 4, 12, 1, 9, 5, 13,
2, 10, 6, 14, 3, 11, 7, 15};
// Original register layout:
//
// [0, 1, 2, 3], [8, 9, 10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15]
//
// Each element in the layout consists of two bf16 values.
// For example, the row [0, 1, 2, 3] expands to:
//
// [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
//
// Here, 0/0 refers to the first half of element 0, and 0/1 refers to the
// second half, matching kWidth = 8.
//
// To derive four independent MMA operations, a stride of 4 is applied to
// the original register layout:
//
// 1st MMA: [0, 4, 8, 12]
// 2nd MMA: [1, 5, 9, 13]
// 3rd MMA: [2, 6, 10, 14]
// 4th MMA: [3, 7, 11, 15]
si = llvm::SmallVector<unsigned>{0, 4, 8, 12, 1, 5, 9, 13,
2, 6, 10, 14, 3, 7, 11, 15};
} else {
// Original register layout:
//
// [0, 1, 2, 3]^T, [4, 5, 6, 7]^T
//
// A stride of 4 is applied to derive four independent MMA operations:
//
// 1st MMA: [0, 4]
// 2nd MMA: [1, 5]
// 3rd MMA: [2, 6]
// 4th MMA: [3, 7]
si = llvm::SmallVector<unsigned>{0, 4, 1, 5, 2, 6, 3, 7};
}

Expand Down Expand Up @@ -112,8 +142,8 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
for (auto i = 0; i < n0; ++i) {
for (auto j = 0; j < n1; j++) {
vals[{b, 2 * i, 2 * j}] = elems[offset++];
vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
vals[{b, 2 * i + 1, 2 * j}] = elems[offset++];
vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
ret.push_back(v);
}
}
// FIXME [Dot LL]
// The DotOperandEncodingAttr without LLs encodes the
// layout as
// e0 e1
// e2 e3
// rather than the transpose of that, as the PTX docs say
// We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2)
assert(ret.size() % 16 == 0);
for (int i = 0; i < ret.size() / 16; ++i) {
for (int j = 0; j < 4; ++j) {
std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);
}
}

return ret;
}
Expand Down