Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_xor_max
tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
// CHECK: rocdl.ds_swizzle
// stride 16: CDNA fallback to bpermute
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this is only checking gfx942. Can you also add check lines for gfx1250 given your changes?

// CHECK: rocdl.ds_bpermute

// stride 8: ROW_ROR:8 (0x128 = 296)
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 12, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 264, 15, 3, false : i32
// CHECK-SAME: with 296, 15, 15, false : i32
// CHECK: llvm.intr.maxnum

// stride 4: ROW_HALF_MIRROR (0x141 = 321) + quad_perm xor 3 (27)
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 276, 15, 10, false : i32
// CHECK-SAME: with 321, 15, 15, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 260, 15, 5, false : i32
// CHECK-SAME: with 27, 15, 15, false : i32
// CHECK: llvm.intr.maxnum

// stride 2: quad_perm xor 2 (78)
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 78, 15, 15, false : i32
// CHECK: llvm.intr.maxnum

// stride 1: quad_perm xor 1 (177)
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 177, 15, 15, false : i32
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
Expand All @@ -260,6 +263,37 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr

// -----

#linear = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [0, 0], [2, 0], [4, 0], [8, 0]], warp = [], block = []}>
#slice = #ttg.slice<{dim = 0, parent = #linear}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
// GFX1250-LABEL: reduce_xor_row_xmask
tt.func @reduce_xor_row_xmask(%arg0: tensor<16x2xf32, #linear>) {
// stride 16
// GFX1250-NOT: rocdl.ds_bpermute
// GFX1250: rocdl.permlanex16

// stride 8: ROW_XMASK:8
// GFX1250: rocdl.update.dpp
// GFX1250-SAME: with 360, 15, 15, false

// stride 4: ROW_XMASK:4
// GFX1250: rocdl.update.dpp
// GFX1250-SAME: with 356, 15, 15, false

// stride 1: ROW_XMASK:1
// GFX1250: rocdl.update.dpp
// GFX1250-SAME: with 353, 15, 15, false
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.maxnumf %arg1, %arg2 : f32
tt.reduce.return %1 : f32
}) : (tensor<16x2xf32, #linear>) -> tensor<2xf32, #slice>
tt.return
}
}

// -----

#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: atomicrmw_scope_memsemantics
Expand Down
6 changes: 5 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ enum class DppCtrl : uint32_t {
QUAD_PERM_FIRST = 0,
ROW_SHL0 = 0x100,
ROW_SHR0 = 0x110,
ROW_ROR0 = 0x120,
ROW_MIRROR = 0x140,
ROW_HALF_MIRROR = 0x141,
BCAST15 = 0x142,
BCAST31 = 0x143
BCAST31 = 0x143,
ROW_XMASK0 = 0x160,
};

} // namespace mlir::triton::AMD
Expand Down
220 changes: 110 additions & 110 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,47 @@ using mlir::triton::AMD::DppCtrl;
using mlir::triton::AMD::ISAFamily;
using mlir::triton::gpu::appendOrGetExternFuncOp;

namespace mlir::LLVM::AMD {
namespace {

enum class ShflKind : uint32_t {
bfly = 0,
up = 1,
down = 2,
idx = 3,
idx = 2,
};
} // namespace

namespace mlir::LLVM::AMD {
static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
ISAFamily isaFamily, Value val, Value i,
int strideInt, ShflKind mode, Value clamp) {
Value emitDpp(Location loc, RewriterBase &rewriter, Value old, Value src,
DppCtrl dppCtrl, uint32_t rowMask = 0xf, uint32_t bankMask = 0xf,
bool boundCtrl = false) {
return ROCDL::DPPUpdateOp::create(
rewriter, loc, src.getType(), old, src,
rewriter.getI32IntegerAttr(static_cast<uint32_t>(dppCtrl)),
rewriter.getI32IntegerAttr(rowMask), rewriter.getI32IntegerAttr(bankMask),
rewriter.getBoolAttr(boundCtrl));
}

Value emitPermlaneX16Xor(Location loc, RewriterBase &rewriter, Value val,
uint32_t rowMask) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
// PermlaneX16 reads the opposite 16-lane half of a 32-lane wave. It uses a
// 16-nibble selector to choose the source lane within that half.
// We use it to perform a shuffleXor with mask `rowMask ^ 16`.
assert(rowMask < 16 && "Expected a cross-row shuffleXor");
auto buildSelectorMask = [&](unsigned startLane) {
uint32_t sel = 0;
for (unsigned lane = 0; lane < 8; ++lane)
sel |= ((startLane + lane) ^ rowMask) << (lane * 4);
return sel;
};
Value loSel = b.i32_val(buildSelectorMask(0));
Value hiSel = b.i32_val(buildSelectorMask(8));
return ROCDL::PermlaneX16Op::create(rewriter, loc, val.getType(), val, val,
loSel, hiSel, true, false);
}

Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
ISAFamily isaFamily, Value val, Value i, int strideInt,
ShflKind mode, Value clamp) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
unsigned bits = val.getType().getIntOrFloatBitWidth();

Expand Down Expand Up @@ -82,110 +110,80 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
};

switch (mode) {
case ShflKind::bfly:
if (strideInt > 16 || (strideInt & (strideInt - 1)) != 0) {
// Non-power-of-2 masks or strides > 16 cannot use DPP or ds_swizzle.
// Fall back to ds_bpermute with an XOR lane index.
Value lineId = b.xor_(laneId, b.i32_val(strideInt));
return bpermute(lineId);
} else if (strideInt == 16) {
if (isRDNA(isaFamily)) {
// Lane i in the upper 16 lanes reads the value from lane i in the lower
// 16 lanes and vice versa.
Value select_lo = b.i32_val(0x76543210);
Value select_hi = b.i32_val(0xfedcba98);
return ROCDL::PermlaneX16Op::create(rewriter, loc, valType, val, val,
select_lo, select_hi, true, false);
} else {
Value offset = b.i32_val(0x401F);
return ROCDL::DsSwizzleOp::create(rewriter, loc, valType, val, offset);
}
} else {
if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3,
ISAFamily::CDNA4, ISAFamily::RDNA3,
ISAFamily::RDNA4, ISAFamily::GFX1250},
isaFamily)) {
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3/RDNA4/GFX1250 right
// now, so we fallback to ds_swizzle for other architectures.
//
// This map facilates the butterfly shuffle pattern for a stride less
// than 16. The pattern stride is the key of the map.
DenseMap<short, unsigned int> masks{
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
Value offset = b.i32_val(masks[strideInt]);
return ROCDL::DsSwizzleOp::create(rewriter, loc, valType, val, offset);
}

auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src,
uint32_t dppCtrl, uint32_t rowMask,
uint32_t bankMask) {
return ROCDL::DPPUpdateOp::create(rewriter, loc, valType, old, src,
rewriter.getI32IntegerAttr(dppCtrl),
rewriter.getI32IntegerAttr(rowMask),
rewriter.getI32IntegerAttr(bankMask),
rewriter.getBoolAttr(false));
};

const int allRows = 0xf;
const int allBanks = 0xf;

switch (strideInt) {
case 1: {
// quad_perm: 1, 0, 3, 2
uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
std::array<uint32_t, 4> mask = {1, 0, 3, 2};
for (int i = 0; i < mask.size(); i++) {
dppCtrl |= mask[i] << (i * 2);
}
return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
allBanks);
}
case 2: {
// quad_perm: 2, 3, 0, 1
uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
std::array<uint32_t, 4> mask = {2, 3, 0, 1};
for (int i = 0; i < mask.size(); i++) {
dppCtrl |= mask[i] << (i * 2);
case ShflKind::bfly: {
// We have the following tools to decompose shuffleXor into DPP primitives.
// On CDNA:
// - xor by 15 : reflect the lanes within each group of 16 lanes,
// - xor by 8 : right rotate by 8 within each group of 16 lanes,
// - xor by 7 : reflect the lanes within each group of 8 lanes,
// - xor by 1-3 : perform a permutation within each group of 4 lanes.
// On RDNA:
// - xor by 1-15 : row_xmask does exactly this.
//
// On RDNA, we also have permlanex16 which can perform xor by 16-31.
// On CDNA, one can see that any shuffleXor with `mask` in [1, 15] can
// be implemented using at most 2 DPP instructions by mapping the upper bits
// of `mask` to the `xor by 7` and `xor by 8` instructions.
uint32_t mask = strideInt;
if (mask == 0)
return val;
assert(mask < iWarpSize &&
"shuffle_xor expects a mask within the wave size");

auto makeDppCtrl = [](DppCtrl base, uint32_t arg) {
return static_cast<DppCtrl>(llvm::to_underlying(base) + arg);
};
auto makeQuadPermCtrl = [](uint32_t quadMask) {
DppCtrl ctrl = DppCtrl::QUAD_PERM_FIRST;
uint32_t ctrlBits = llvm::to_underlying(ctrl);
for (unsigned lane = 0; lane < 4; ++lane)
ctrlBits |= (lane ^ quadMask) << (lane * 2);
return static_cast<DppCtrl>(ctrlBits);
};

if (isRDNA(isaFamily) || isaFamily == ISAFamily::GFX1250) {
if (mask < 16)
return emitDpp(loc, rewriter, val, val,
makeDppCtrl(DppCtrl::ROW_XMASK0, mask));
else if (mask < 32)
return emitPermlaneX16Xor(loc, rewriter, val, mask & 0xf);
} else if ((isCDNA(isaFamily) || isaFamily == ISAFamily::GCN5_1) &&
mask < 16) {
Value result = val;
uint32_t highBitsDppBasis = 0;

if (mask & 4)
highBitsDppBasis ^= 7;
if (mask & 8)
highBitsDppBasis ^= 8;

uint32_t quadMask = mask ^ highBitsDppBasis;

if (highBitsDppBasis) {
DppCtrl highBitsDppCtrl;
switch (highBitsDppBasis) {
case 0x7:
highBitsDppCtrl = DppCtrl::ROW_HALF_MIRROR;
break;
case 0x8:
highBitsDppCtrl = makeDppCtrl(DppCtrl::ROW_ROR0, 8);
break;
case 0xf:
highBitsDppCtrl = DppCtrl::ROW_MIRROR;
break;
}
return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
allBanks);
}
case 4: {
// row_shr:4 bank_mask: 0xa
auto ret = createDppOpWithoutBoundCtrl(
val, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
allRows, 0xa)
.getRes();

// row_shl:4 bank_mask: 0x5
return createDppOpWithoutBoundCtrl(
ret, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
0x5);
}
case 8: {
// row_shr:8 bank_mask: 0xc
auto ret = createDppOpWithoutBoundCtrl(
val, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
allRows, 0xc)
.getRes();

// row_shl:8 bank_mask: 0x3
return createDppOpWithoutBoundCtrl(
ret, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
0x3);
}
default: {
// Non-power-of-2 strides (e.g. 3, 5, 6) arise from layout conversions
// that combine multiple lane bits into a single XOR mask. DPP can only
// express single-bit strides (1, 2, 4, 8); fall back to bpermute for
// the rest.
Value stride = b.i32_val(strideInt);
Value lineId = b.xor_(laneId, stride);
return bpermute(lineId);
result = emitDpp(loc, rewriter, result, result, highBitsDppCtrl);
}

if (quadMask) {
result =
emitDpp(loc, rewriter, result, result, makeQuadPermCtrl(quadMask));
}
return result;
} else {
return bpermute(b.xor_(laneId, b.i32_val(mask)));
}
break;
}
case ShflKind::up: {
Value mask = b.icmp_slt(laneId, i);
Value delta = b.sub(laneId, i);
Expand All @@ -201,9 +199,9 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
return Value();
}

static Value shuffleCommon(Location loc, RewriterBase &rewriter,
ISAFamily isaFamily, Value val, Value i,
int strideInt, ShflKind mode, Value clamp) {
Value shuffleCommon(Location loc, RewriterBase &rewriter, ISAFamily isaFamily,
Value val, Value i, int strideInt, ShflKind mode,
Value clamp) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
// To shuffle pointers, convert them to i64.
Type valTy = val.getType();
Expand All @@ -216,6 +214,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter,
return result;
}

} // namespace

Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i,
ISAFamily isaFamily) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Expand Down
Loading