From c4f1e2a49357d180127a1fc85af467b3280327c4 Mon Sep 17 00:00:00 2001 From: Frederick Vu <100011202+FrederickVu@users.noreply.github.com> Date: Fri, 17 Apr 2026 18:09:38 +0000 Subject: [PATCH 1/3] Clean up shufflexor implementation --- test/Conversion/amd/tritongpu_to_llvm.mlir | 15 +- .../include/TritonAMDGPUToLLVM/TargetUtils.h | 6 +- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 223 +++++++++--------- 3 files changed, 127 insertions(+), 117 deletions(-) diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 5a9d4e2dc4cc..787cd5f1f062 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -229,24 +229,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: reduce_xor_max tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { - // CHECK: rocdl.ds_swizzle + // stride 16: CDNA fallback to bpermute + // CHECK: rocdl.ds_bpermute + // stride 8: ROW_ROR:8 (0x128 = 296) // CHECK: rocdl.update.dpp - // CHECK-SAME: with 280, 15, 12, false : i32 - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 264, 15, 3, false : i32 + // CHECK-SAME: with 296, 15, 15, false : i32 // CHECK: llvm.intr.maxnum + // stride 4: ROW_HALF_MIRROR (0x141 = 321) + quad_perm xor 3 (27) // CHECK: rocdl.update.dpp - // CHECK-SAME: with 276, 15, 10, false : i32 + // CHECK-SAME: with 321, 15, 15, false : i32 // CHECK: rocdl.update.dpp - // CHECK-SAME: with 260, 15, 5, false : i32 + // CHECK-SAME: with 27, 15, 15, false : i32 // CHECK: llvm.intr.maxnum + // stride 2: quad_perm xor 2 (78) // CHECK: rocdl.update.dpp // CHECK-SAME: with 78, 15, 15, false : i32 // CHECK: llvm.intr.maxnum + // stride 1: quad_perm xor 1 (177) // CHECK: rocdl.update.dpp // CHECK-SAME: with 177, 15, 15, false : i32 %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index 5d84110dc263..6a6b466ead87 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -37,8 +37,12 @@ enum class DppCtrl : uint32_t { QUAD_PERM_FIRST = 0, ROW_SHL0 = 0x100, ROW_SHR0 = 0x110, + ROW_ROR0 = 0x120, + ROW_MIRROR = 0x140, + ROW_HALF_MIRROR = 0x141, BCAST15 = 0x142, - BCAST31 = 0x143 + BCAST31 = 0x143, + ROW_XMASK0 = 0x160, }; } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 1ce9740295a3..9880d2ae71a9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -16,19 +16,50 @@ using mlir::triton::AMD::DppCtrl; using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; +namespace mlir::LLVM::AMD { namespace { + enum class ShflKind : uint32_t { bfly = 0, up = 1, - down = 2, - idx = 3, + idx = 2, }; -} // namespace -namespace mlir::LLVM::AMD { -static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, - ISAFamily isaFamily, Value val, Value i, - int strideInt, ShflKind mode, Value clamp) { +Value emitDpp(Location loc, RewriterBase &rewriter, Value old, Value src, + DppCtrl dppCtrl, uint32_t rowMask = 0xf, uint32_t bankMask = 0xf, + bool boundCtrl = false) { + return ROCDL::DPPUpdateOp::create( + rewriter, loc, src.getType(), old, src, + rewriter.getI32IntegerAttr(static_cast(dppCtrl)), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(boundCtrl)) + .getRes(); +} + +Value emitPermlaneX16Xor(Location loc, RewriterBase &rewriter, Value val, + uint32_t rowMask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // PermlaneX16 reads the opposite 16-lane half of a 32-lane wave. It uses a + // 16-nibble selector to choose the source lane within that half. + // We use it to perform a shuffleXor with mask `rowMask ^ 16`. + assert(rowMask < 16 && "Expected a cross-row shuffleXor"); + auto buildSelectorMask = [&](unsigned startLane) { + uint32_t sel = 0; + for (unsigned lane = 0; lane < 8; ++lane) + sel |= ((startLane + lane) ^ rowMask) << (lane * 4); + return sel; + }; + Value loSel = b.i32_val(buildSelectorMask(0)); + Value hiSel = b.i32_val(buildSelectorMask(8)); + return ROCDL::PermlaneX16Op::create(rewriter, loc, val.getType(), val, val, + loSel, hiSel, true, false) + .getRes(); +} + +Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, int strideInt, + ShflKind mode, Value clamp) { auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned bits = val.getType().getIntOrFloatBitWidth(); @@ -82,110 +113,80 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, }; switch (mode) { - case ShflKind::bfly: - if (strideInt > 16 || (strideInt & (strideInt - 1)) != 0) { - // Non-power-of-2 masks or strides > 16 cannot use DPP or ds_swizzle. - // Fall back to ds_bpermute with an XOR lane index. - Value lineId = b.xor_(laneId, b.i32_val(strideInt)); - return bpermute(lineId); - } else if (strideInt == 16) { - if (isRDNA(isaFamily)) { - // Lane i in the upper 16 lanes reads the value from lane i in the lower - // 16 lanes and vice versa. - Value select_lo = b.i32_val(0x76543210); - Value select_hi = b.i32_val(0xfedcba98); - return ROCDL::PermlaneX16Op::create(rewriter, loc, valType, val, val, - select_lo, select_hi, true, false); - } else { - Value offset = b.i32_val(0x401F); - return ROCDL::DsSwizzleOp::create(rewriter, loc, valType, val, offset); - } - } else { - if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3, - ISAFamily::CDNA4, ISAFamily::RDNA3, - ISAFamily::RDNA4, ISAFamily::GFX1250}, - isaFamily)) { - // DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3/RDNA4/GFX1250 right - // now, so we fallback to ds_swizzle for other architectures. - // - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = b.i32_val(masks[strideInt]); - return ROCDL::DsSwizzleOp::create(rewriter, loc, valType, val, offset); - } - - auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, - uint32_t dppCtrl, uint32_t rowMask, - uint32_t bankMask) { - return ROCDL::DPPUpdateOp::create(rewriter, loc, valType, old, src, - rewriter.getI32IntegerAttr(dppCtrl), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), - rewriter.getBoolAttr(false)); - }; - - const int allRows = 0xf; - const int allBanks = 0xf; - - switch (strideInt) { - case 1: { - // quad_perm: 1, 0, 3, 2 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {1, 0, 3, 2}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); - } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 2: { - // quad_perm: 2, 3, 0, 1 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {2, 3, 0, 1}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); + case ShflKind::bfly: { + // We have the following tools to decompose shuffleXor into DPP primitives. + // On CDNA: + // - xor by 15 : reflect the lanes within each group of 16 lanes, + // - xor by 8 : right rotate by 8 within each group of 16 lanes, + // - xor by 7 : reflect the lanes within each group of 8 lanes, + // - xor by 1-3 : perform a permutation within each group of 4 lanes. + // On RDNA: + // - xor by 1-15 : row_xmask does exactly this. + // + // On RDNA, we also have permlanex16 which can perform xor by 16-31. + // On CDNA, one can see that any shuffleXor with `mask` in [1, 15] can + // be implemented using at most 2 DPP instructions by mapping the upper bits + // of `mask` to the `xor by 7` and `xor by 8` instructions. + uint32_t mask = strideInt; + if (mask == 0) + return val; + assert(mask < iWarpSize && + "shuffle_xor expects a mask within the wave size"); + + auto makeDppCtrl = [](DppCtrl base, uint32_t arg) { + return static_cast(llvm::to_underlying(base) + arg); + }; + auto makeQuadPermCtrl = [](uint32_t quadMask) { + DppCtrl ctrl = DppCtrl::QUAD_PERM_FIRST; + uint32_t ctrlBits = llvm::to_underlying(ctrl); + for (unsigned lane = 0; lane < 4; ++lane) + ctrlBits |= (lane ^ quadMask) << (lane * 2); + return static_cast(ctrlBits); + }; + + if (isRDNA(isaFamily) || isaFamily == ISAFamily::GFX1250) { + if (mask < 16) + return emitDpp(loc, rewriter, val, val, + makeDppCtrl(DppCtrl::ROW_XMASK0, mask)); + else if (mask < 32) + return emitPermlaneX16Xor(loc, rewriter, val, mask & 0xf); + } else if ((isCDNA(isaFamily) || isaFamily == ISAFamily::GCN5_1) && + mask < 16) { + Value result = val; + uint32_t highBitsDppBasis = 0; + + if (mask & 4) + highBitsDppBasis ^= 7; + if (mask & 8) + highBitsDppBasis ^= 8; + + uint32_t quadMask = mask ^ highBitsDppBasis; + + if (highBitsDppBasis) { + DppCtrl highBitsDppCtrl; + switch (highBitsDppBasis) { + case 0x7: + highBitsDppCtrl = DppCtrl::ROW_HALF_MIRROR; + break; + case 0x8: + highBitsDppCtrl = makeDppCtrl(DppCtrl::ROW_ROR0, 8); + break; + case 0xf: + highBitsDppCtrl = DppCtrl::ROW_MIRROR; + break; } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 4: { - // row_shr:4 bank_mask: 0xa - auto ret = createDppOpWithoutBoundCtrl( - val, val, 4 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xa) - .getRes(); - - // row_shl:4 bank_mask: 0x5 - return createDppOpWithoutBoundCtrl( - ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x5); - } - case 8: { - // row_shr:8 bank_mask: 0xc - auto ret = createDppOpWithoutBoundCtrl( - val, val, 8 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xc) - .getRes(); - - // row_shl:8 bank_mask: 0x3 - return createDppOpWithoutBoundCtrl( - ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x3); - } - default: { - // Non-power-of-2 strides (e.g. 3, 5, 6) arise from layout conversions - // that combine multiple lane bits into a single XOR mask. DPP can only - // express single-bit strides (1, 2, 4, 8); fall back to bpermute for - // the rest. - Value stride = b.i32_val(strideInt); - Value lineId = b.xor_(laneId, stride); - return bpermute(lineId); + result = emitDpp(loc, rewriter, result, result, highBitsDppCtrl); } + + if (quadMask) { + result = + emitDpp(loc, rewriter, result, result, makeQuadPermCtrl(quadMask)); } + return result; + } else { + return bpermute(b.xor_(laneId, b.i32_val(mask))); } - break; + } case ShflKind::up: { Value mask = b.icmp_slt(laneId, i); Value delta = b.sub(laneId, i); @@ -201,9 +202,9 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, return Value(); } -static Value shuffleCommon(Location loc, RewriterBase &rewriter, - ISAFamily isaFamily, Value val, Value i, - int strideInt, ShflKind mode, Value clamp) { +Value shuffleCommon(Location loc, RewriterBase &rewriter, ISAFamily isaFamily, + Value val, Value i, int strideInt, ShflKind mode, + Value clamp) { auto b = TritonLLVMOpBuilder(loc, rewriter); // To shuffle pointers, convert them to i64. Type valTy = val.getType(); @@ -216,6 +217,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, return result; } +} // namespace + Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, ISAFamily isaFamily) { auto b = TritonLLVMOpBuilder(loc, rewriter); From 67209a61540c6904dfab358bcf2343c467a2969b Mon Sep 17 00:00:00 2001 From: Frederick Vu <100011202+FrederickVu@users.noreply.github.com> Date: Fri, 17 Apr 2026 20:21:07 +0000 Subject: [PATCH 2/3] retrigger CI From 1a9639a5f59b20530fca589b03678b2d8de977dd Mon Sep 17 00:00:00 2001 From: Frederick Vu <100011202+FrederickVu@users.noreply.github.com> Date: Sat, 18 Apr 2026 02:12:08 +0000 Subject: [PATCH 3/3] Add gfx1250 test and drop getRes() calls --- test/Conversion/amd/tritongpu_to_llvm.mlir | 31 +++++++++++++++++++ .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 13 +++----- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 787cd5f1f062..bfa139b4789c 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -263,6 +263,37 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr // ----- +#linear = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [0, 0], [2, 0], [4, 0], [8, 0]], warp = [], block = []}> +#slice = #ttg.slice<{dim = 0, parent = #linear}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // GFX1250-LABEL: reduce_xor_row_xmask + tt.func @reduce_xor_row_xmask(%arg0: tensor<16x2xf32, #linear>) { + // stride 16 + // GFX1250-NOT: rocdl.ds_bpermute + // GFX1250: rocdl.permlanex16 + + // stride 8: ROW_XMASK:8 + // GFX1250: rocdl.update.dpp + // GFX1250-SAME: with 360, 15, 15, false + + // stride 4: ROW_XMASK:4 + // GFX1250: rocdl.update.dpp + // GFX1250-SAME: with 356, 15, 15, false + + // stride 1: ROW_XMASK:1 + // GFX1250: rocdl.update.dpp + // GFX1250-SAME: with 353, 15, 15, false + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x2xf32, #linear>) -> tensor<2xf32, #slice> + tt.return + } +} + +// ----- + #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomicrmw_scope_memsemantics diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 9880d2ae71a9..3b643f379833 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -29,12 +29,10 @@ Value emitDpp(Location loc, RewriterBase &rewriter, Value old, Value src, DppCtrl dppCtrl, uint32_t rowMask = 0xf, uint32_t bankMask = 0xf, bool boundCtrl = false) { return ROCDL::DPPUpdateOp::create( - rewriter, loc, src.getType(), old, src, - rewriter.getI32IntegerAttr(static_cast(dppCtrl)), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), - rewriter.getBoolAttr(boundCtrl)) - .getRes(); + rewriter, loc, src.getType(), old, src, + rewriter.getI32IntegerAttr(static_cast(dppCtrl)), + rewriter.getI32IntegerAttr(rowMask), rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(boundCtrl)); } Value emitPermlaneX16Xor(Location loc, RewriterBase &rewriter, Value val, @@ -53,8 +51,7 @@ Value emitPermlaneX16Xor(Location loc, RewriterBase &rewriter, Value val, Value loSel = b.i32_val(buildSelectorMask(0)); Value hiSel = b.i32_val(buildSelectorMask(8)); return ROCDL::PermlaneX16Op::create(rewriter, loc, val.getType(), val, val, - loSel, hiSel, true, false) - .getRes(); + loSel, hiSel, true, false); } Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,