diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index b000a3129912..50d024794663 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -b5cc222d7429fe6f18c787f633d5262fac2e676f +86b69c31642e98f8357df62c09d118ad1da4e16a diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 283dd9165918..c39c408d9330 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -727,6 +727,10 @@ def TT_ReduceOp: TT_Op<"reduce", llvm::SmallVector<RankedTensorType> getInputTypes(); llvm::SmallVector<Type> getElementTypes(); unsigned getNumOperands(); + + // Returns the single combining op iff this ReduceOp's region contains + // exactly one op besides the terminator; returns nullptr otherwise. + ::mlir::Operation *getSingleCombiner(); }]; } diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 34fb8995430f..06e75ee18d59 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -56,20 +56,19 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // This will create newArg, and map(origArg, newArg) addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, - Location loc) -> std::optional<Value> { + Location loc) -> Value { llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); - return std::nullopt; + return {}; }); // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, - Location loc) -> std::optional<Value> { + ValueRange inputs, Location loc) -> Value { llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); - return std::nullopt; + return {}; }); // This will be called when (desiredType != newOperandType) @@ -79,7 +78,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, ValueRange inputs, Location loc) { auto cast = builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs); - return std::optional<Value>(cast.getResult()); + return cast.getResult(); }); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ffea5f3c67a6..e77e2d5c8691 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() { return getElementTypesImpl(this->getOperands()); } +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- ScanOp -- diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index ef6733845721..72e02a4ef46e 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -62,3 +62,111 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32,
"triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16 + tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16 + tt.func @atomic_add_bf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_dpp_max + tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 274, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 273, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 322, 10, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 323, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK: llvm.amdgcn.readlane + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64xf32, #blocked3>) -> f32 + tt.return + } +} + +// ----- + +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_xor_max + tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { + // CHECK: 
rocdl.ds_swizzle + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 12, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 264, 15, 3, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 10, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 260, 15, 5, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 78, 15, 15, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 177, 15, 15, false : i32 + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32xf32, #blocked4>) -> f32 + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 4fb418e3811b..25897f2a9378 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -42,7 +42,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = arith.muli %0, %c1024_i32 : i32 %sub = arith.subi %1, %c128_i32 : i32 %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 - "llvm.intr.assume"(%cmp) : (i1) -> () + llvm.intr.assume %cmp : i1 %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> // CHECK: %[[offset:.*]] = arith.addi diff --git a/test/lib/Instrumentation/GPUHello.cpp b/test/lib/Instrumentation/GPUHello.cpp index 3bee8ce90ced..5c71857c8f36 100644 --- a/test/lib/Instrumentation/GPUHello.cpp +++ b/test/lib/Instrumentation/GPUHello.cpp @@ -61,7 +61,7 @@ bool GpuHello::runOnModule(Module &module) { PassPluginLibraryInfo getPassPluginInfo() { const auto callback = [](PassBuilder &pb) { - pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto) { + pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) { mpm.addPass(GpuHello()); return true; }); diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 51aa389a0681..0656c9d99419 100644 --- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -130,6 +130,7 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX1151 = 0x04a, EF_AMDGPU_MACH_AMDGCN_GFX941 = 0x04b, EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, + EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, // First/last AMDGCN-based processors. 
EF_AMDGPU_MACH_AMDGCN_FIRST = EF_AMDGPU_MACH_AMDGCN_GFX600, diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index a7395f86dc50..6dbb0435e20c 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" + // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index a49e442d3984..9e174d545dd9 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -19,6 +19,17 @@ enum class ISAFamily { // Deduces the corresponding ISA family for the given target gfx |arch|. ISAFamily deduceISAFamily(llvm::StringRef arch); +// Here is a partial definition of the DppCtrl enum. For the complete +// definition, please check: +// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 +enum class DppCtrl : uint32_t { + QUAD_PERM_FIRST = 0, + ROW_SHL0 = 0x100, + ROW_SHR0 = 0x110, + BCAST15 = 0x142, + BCAST31 = 0x143 +}; + } // namespace mlir::triton::AMD #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a45efd4a7971..5265f631ad9e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -768,7 +768,11 @@ struct AtomicRMWOpConversion // tensor if (tensorTy) { auto valTy = cast<RankedTensorType>(val.getType()); - vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1); + Type elTy = valTy.getElementType(); + vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) && + elTy.getIntOrFloatBitWidth() == 16 + ? 2 + : 1); // mask numElems = tensorTy.getNumElements(); } @@ -783,13 +787,22 @@ struct AtomicRMWOpConversion auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; SmallVector<Value> resultVals(elemsPerThread); - const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + Value operand; + if (vec == 1) { + operand = valElements[i]; + } else { + operand = undef(vecTy); + for (size_t ii = 0; ii < vec; ++ii) + operand = + insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); + } + Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); @@ -806,25 +819,11 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. - Value atom = rewriter - .create<LLVM::AtomicRMWOp>( - loc, *maybeKind, rmwPtr, valElements[i], - atomicMemOrdering, StringRef("agent")) - .getResult(); - - // NV for the f16v2 case generates one packed instruction. We have to - // create two separate instructions since LLVM::AtomicRMWOp doesn't - // support this. Can be optimized out with rocdl.raw.buffer.atomic. 
- if (f16v2) { - Value atom2 = - rewriter - .create<LLVM::AtomicRMWOp>( - loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], - atomicMemOrdering, StringRef("agent")) - .getResult(); - auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); - atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); - } + Value atom = + rewriter + .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand, + atomicMemOrdering, StringRef("agent")) + .getResult(); if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 3a40d73c2a7c..525361fee603 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +using mlir::triton::AMD::DppCtrl; namespace mlir::triton::AMD { namespace { @@ -103,22 +104,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleXor(loc, rewriter, val, i); + return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleUp(loc, rewriter, val, i); + return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::programId(RewriterBase &rewriter, Location loc, @@ -126,11 +127,184 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } +// If necessary, cast and sign-extend a value to an integer of the required +// bit width to meet the operand requirements of instructions like UpdateDpp +// or readlane. +static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, + Value &val, Type fromType, + unsigned toBits) { + unsigned originalBits = fromType.getIntOrFloatBitWidth(); + Type toType = fromType; + + if (!fromType.isIntOrIndex()) { + val = bitcast(val, int_ty(originalBits)); + toType = int_ty(originalBits); + } + + if (originalBits < toBits) { + val = sext(int_ty(toBits), val); + toType = int_ty(toBits); + } + + return toType; +} + +// Truncate the value to the given bit width and then cast it back to the +// original type if necessary. This function is typically used in conjunction +// with castToAndSExtInt. 
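+// For example (an illustrative trace, assuming an f16 input): castToAndSExtInt +// bitcasts the f16 value to i16 and sign-extends it to i32 before a DPP or +// readlane instruction, and truncAndCastFromInt then truncates the i32 result +// back to i16 and bitcasts it to f16, recovering the original type.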
+static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, + Value val, Type valType, + unsigned fromBits) { + unsigned originalBits = valType.getIntOrFloatBitWidth(); + Value toVal = val; + + if (originalBits < fromBits) { + toVal = trunc(int_ty(originalBits), toVal); + } + + if (!valType.isIntOrIndex()) { + toVal = bitcast(toVal, valType); + } + + return toVal; +} + bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector<Value> &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - return false; + if (numLaneToReduce != 64) + return false; + + if (auto family = getISAFamily(); + family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) { + return false; + } + + Operation *reduxOp = op.getSingleCombiner(); + if (!reduxOp) + return false; + + auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, + uint32_t dppCtrl, int rowMask, + int bankMask) -> Value { + // DPP has limited support for data types, so here we need to + // cast non-integer types or integer types shorter than 32 bits + // to int32, except for fp32. + Type actualType = valType; + if (!valType.isF32()) { + actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); + } + + Value dppResult = + rewriter + .create<ROCDL::DPPUpdateOp>(loc, actualType, src, src, + rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(true)) + .getRes(); + + if (!valType.isF32()) { + src = truncAndCastFromInt(rewriter, loc, src, valType, 32); + dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); + } + + IRMapping mapping; + mapping.map(reduxOp->getOperand(0), src); + mapping.map(reduxOp->getOperand(1), dppResult); + return rewriter.clone(*reduxOp, mapping)->getResult(0); + }; + + for (int i = 0; i < acc.size(); i++) { + Value buf; + auto valType = acc[i].getType(); + + /* + Here's the implementation of full-wavefront reduction using dpp. + https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + + Each step has a v_mov_dpp instruction following the redux op. In + some cases, the lower-level compiler could merge them into a single + instruction. For example, v_mov_dpp + max => v_max_dpp. + + For gfx9, we have 64 threads per warp. These 64 threads are arranged + into 4 rows, with each row being 16 threads. Each row of 16 threads is + further arranged into 4 banks, with each bank being 4 threads. Overall + it forms a (row, bank, thread) structure. When shuffling, we use + row/bank masks to indicate which rows/banks participate, and modifiers + like row_shr and row_bcast specify the exact data movement scheme. In + the following instructions, taking row 0 as an example: + + Step 1: Right shift for 8 lanes. + lane 8-15 = redux(lane 0-7, lane 8-15) + + Step 2: Right shift for 4 lanes. + lane 12-15 = redux(lane 8-11, lane 12-15) + + Step 3: Right shift for 2 lanes. + lane 14-15 = redux(lane 12-13, lane 14-15) + + Step 4: Right shift for 1 lane. + lane 15 = redux(lane 14, lane 15) + + Step 5: Broadcast lane 15 of each row to all the lanes of its next row. + lane 16-31 = redux(lane 15, lane 16-31) + + Step 6: Broadcast lane 31 to lane 32-63. + lane 32-63 = redux(lane 31, lane 32-63) + + Now the reduction result is stored in lane 63. + + Step 7: Read the reduction result from lane 63 and broadcast with + readlane. 
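+ + As an illustrative example (values are hypothetical): for a max reduction + where lane i initially holds the value i, step 1 leaves lane 15 of row 0 + holding max(7, 15) = 15, steps 2-4 leave it holding the maximum of lanes + 0-15, and steps 5-6 combine the per-row results so that lane 63 ends up + with the maximum over all 64 lanes, which step 7 broadcasts via readlane.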
+ */ + + const int allRows = 0xf; + const int allBanks = 0xf; + + const uint32_t dppCtrlRowShr = static_cast<uint32_t>(DppCtrl::ROW_SHR0); + + // row_shr:8 + buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:4 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:2 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:1 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, + allRows, allBanks); + + // row_bcast:15 row_mask:0xa + buf = createDppReduxOpWithBoundCtrl( + valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks); + + // row_bcast:31 + buf = createDppReduxOpWithBoundCtrl(valType, buf, + static_cast<uint32_t>(DppCtrl::BCAST31), + allRows, allBanks); + + // Similarly, we need to cast data types for the readlane instruction. + Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); + + // Get reduction result from lane 63 + std::string intrinsic = "llvm.amdgcn.readlane"; + Value result = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, + ValueRange{buf, i32_val(63)}) + ->getResult(0); + + result = truncAndCastFromInt(rewriter, loc, result, valType, 16); + + acc[i] = result; + } + + return true; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 63fb972f7903..7ab6fd68a5d5 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -11,6 +11,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { // CDNA ISA cases switch (kind) { + case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: case llvm::AMDGPU::GK_GFX941: case llvm::AMDGPU::GK_GFX940: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 542b1ecbb7fb..0bd401f1993a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,6 +8,8 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::AMD::DppCtrl; +using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -71,8 +73,9 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, - Value i, int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -84,7 +87,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, if (bits < 32) val = sext(i32_ty, val); - val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); + val = + shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -98,8 +102,10 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = 
shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); - val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); + val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, + clamp); + val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, + clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -134,13 +140,83 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value stride = i32_val(32); Value lineId = xor_(threadId, stride); return bpermute(lineId); - } else { - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap<short, unsigned> masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); + } else if (strideInt == 16) { + Value offset = i32_val(0x401F); return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset); + } else { + if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { + // DPP is only supported on CDNA2 and CDNA3 right now, so we fall back + // to ds_swizzle for other archs. + // + // This map facilitates the butterfly shuffle pattern for a stride less + // than 16. The pattern stride is the key of the map. + DenseMap<short, unsigned> masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = i32_val(masks[strideInt]); + return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset); + } + + auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, + uint32_t dppCtrl, uint32_t rowMask, + uint32_t bankMask) { + return rewriter.create<ROCDL::DPPUpdateOp>( + loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); + }; + + const int allRows = 0xf; + const int allBanks = 0xf; + + switch (strideInt) { + case 1: { + // quad_perm: 1, 0, 3, 2 + uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {1, 0, 3, 2}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 2: { + // quad_perm: 2, 3, 0, 1 + uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {2, 3, 0, 1}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 4: { + // row_shr:4 bank_mask: 0xa + auto ret = createDppOpWithoutBoundCtrl( + val, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHR0), + allRows, 0xa) + .getRes(); + + // row_shl:4 bank_mask: 0x5 + return createDppOpWithoutBoundCtrl( + ret, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows, + 0x5); + } + case 8: { + // row_shr:8 bank_mask: 0xc + auto ret = createDppOpWithoutBoundCtrl( + val, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHR0), + allRows, 0xc) + .getRes(); + + // row_shl:8 bank_mask: 0x3 + return createDppOpWithoutBoundCtrl( + ret, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows, + 0x3); + } + default: + assert(false && + "bfly shfl with stride >= 16 should not be handled by dpp."); + } } break; case ShflKind::up: { @@ -158,22 +234,27 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, return Value(); } -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, - i32_val(0x1f)); +Value 
shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, - i32_val(0x0)); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleIdx(loc, rewriter, val, i32_val(i)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { - return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, + i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 123234fd4824..d150531848e3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -2,12 +2,14 @@ #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; @@ -19,10 +21,18 @@ const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS"; const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 7da8083cfb92..c3a69a5f9a2a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -8,6 +8,7 @@ 
add_triton_library(TritonAMDGPUTransforms MfmaGroup.cpp DEPENDS + TritonAMDGPUIR TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index e122f15fd901..22349c50e308 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -5,23 +5,28 @@ #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include <algorithm> - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; namespace ttg = mlir::triton::gpu; -namespace tt = mlir::triton; - -static bool isLocalLoadOrDotLayoutConversion(Operation *op) { - if (isa<ttg::LocalLoadOp>(op)) - return true; - if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(op)) - return isa<ttg::DotOperandEncodingAttr>(cvt.getType().getEncoding()); - return false; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +// Return true if the given moduleOp contains a pure matmul problem; i.e., +// a single dot in the main loop. +static bool isPureMatmulProblem(ModuleOp moduleOp) { + for (auto forOp : moduleOp.getOps<scf::ForOp>()) { + int counter = 0; + forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); + if (counter != 1) + return false; + } + return true; } // Search through block to find earliest insertion point for move op. This can @@ -61,194 +66,233 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Return the first user of the given op in the same block. If the user is in +// a nested block then return the op owning the block. Return nullptr if no +// such user exists. +static Operation *getFirstUseInSameBlock(Operation *op) { + SmallVector<Operation *> usersInSameBlock; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + usersInSameBlock.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; +} + // Check if the operation opInsideLoop is inside any scf::ForOp and // opOutsideLoop is not inside the same loop. -bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, - mlir::Operation *opOutsideLoop) { +static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { scf::ForOp parentForOp = opInsideLoop->getParentOfType<scf::ForOp>(); return parentForOp && !parentForOp->isAncestor(opOutsideLoop); } -class TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { -public: - TritonAMDGPUReorderInstructionsPass() = default; - - Operation *getFirstUse(Operation *op) { - std::vector<Operation *> users; - for (auto user : op->getUsers()) { - if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) - users.push_back(ancestor); - } - auto minOpIt = std::min_element(users.begin(), users.end(), - [](mlir::Operation *a, mlir::Operation *b) { - return a->isBeforeInBlock(b); - }); - return minOpIt != users.end() ?
*minOpIt : nullptr; - } +//===----------------------------------------------------------------------===// +// Reorder mechanisms +//===----------------------------------------------------------------------===// - void runOnOperation() override { - ModuleOp m = getOperation(); +// Sink dot layout conversions into loops to decrease register pressure when +// possible. +static void sinkDotConversion(ModuleOp moduleOp) { + DenseMap<Operation *, Operation *> opToMove; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { + Attribute encoding = op.getType().getEncoding(); + if (!isa_and_nonnull<ttg::DotOperandEncodingAttr>(encoding)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType<scf::ForOp>() == + op->getParentOfType<scf::ForOp>()) + return; + opToMove[op] = user; + }); - // Sink shared memory loads and layout conversions into loops to decrease - // register pressure when possible. - DenseMap<Operation *, Operation *> opToMove; - m.walk([&](Operation *op) { - if (!isLocalLoadOrDotLayoutConversion(op)) - return; - if (!op->hasOneUse()) - return; - Operation *user = *op->getUsers().begin(); - if (user->getParentOfType<scf::ForOp>() == - op->getParentOfType<scf::ForOp>()) + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); +} + +// Adjust the placement of shared memory writes and reads to immediately follow +// the definition of their operands in cases where the shared memory write is +// in the loop but its operand is not. +// +// This heuristic is driven by optimizing fused attention: it hoists the Q +// tensor's shared memory read/write operations outside of the loop, as Q is +// loop invariant and can be loaded once before entering the loop. But it +// should be generally applicable. +// +// There are two possible patterns for this adjustment depending on whether the +// write to shared memory is performed using an optional `local_alloc` argument +// or a `local_store` instruction. +// +// 1) %1 = some_op ... (typically a load or an operation that scales the tensor +// after loading) +// %2 = local_alloc %1 +// %3 = local_load %2 +// +// 2) %1 = some_op ... +// %2 = local_alloc +// %3 = local_store %1, %2 +// %4 = local_load %2 +static void hoistLocalLoad(ModuleOp moduleOp) { + moduleOp.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>(); + if (!localAlloc) + return; + + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) return; - opToMove.insert({op, user}); - }); - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); - opToMove.clear(); - - // Adjust the placement of LDS writes and reads to immediately follow the - // definition of their operands in case where LDS write is in the - // loop but it's operand is not. This is a heuristic for optimizing fused - // attention by hoisting Q tensor LDS read/write operations outside of the - // loop, as Q is a loop invariant and can be loaded once before entering the - // loop. - // There are two possible patterns for this adjustment depending on - // whether the write to LDS is performed using an optional `local_alloc` - // argument or a `local_store` instruction. - // - // clang-format off - // - // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) - // %2 = local_alloc %1 - // %3 = local_load %2 - // - // 2) %1 = some_op ... 
- // %2 = local_alloc - // %3 = local_store %1, %2 - // %4 = local_load %2 - // - // clang-format on - m.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) + + auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); + // Check if localAlloc is in the loop but its src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) return; - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) - return; + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } - auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { - return; - } + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); + // localStore comes before localLoad in block. + Operation *localStore = getFirstUseInSameBlock(localAlloc); + if (!isa<ttg::LocalStoreOp>(localStore)) + return; - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but its src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } - // localStore comes before localLoad in block. - Operation *localStore = getFirstUse(localAlloc); - if (!isa<ttg::LocalStoreOp>(localStore)) - return; + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); +} - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; - } +// Sink conversion after the last dealloc but before the first use in its block. +// This helps to avoid unnecessary shared memory allocation. +static void moveDownConversion(ModuleOp moduleOp) { + SmallVector<Operation *> convertOps; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); + for (auto op : convertOps) { + Operation *user = getFirstUseInSameBlock(op); + for (auto it = Block::iterator(op), ie = op->getBlock()->end(); + it != ie && &*it != user; ++it) + if (isa<ttg::LocalDeallocOp>(&*it)) + op->moveAfter(&*it); + } +} - // Sink conversion after the last dealloc but before the first use ancestor - // in its block. This helps to avoid unnecessary shared memory allocation. - m.walk([&](triton::gpu::ConvertLayoutOp op) { - auto curr = mlir::Block::iterator(op); - for (; &*curr != getFirstUse(op); curr++) - if (isa<triton::gpu::LocalDeallocOp>(&*curr)) - op->moveAfter(&*curr); - }); +// Move transpositions just after their definition. 
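+// Heuristically, keeping a transpose right next to its producer keeps the +// use-def chain short and can shorten the live range of the transposed value.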
+static void moveUpTranspose(ModuleOp moduleOp) { + SmallVector<Operation *> transOps; + moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); - // Move transpositions just after their definition. - m.walk([&](triton::TransOp op) { - if (Operation *argOp = op.getSrc().getDefiningOp()) - op->moveAfter(argOp); - }); + for (auto op : transOps) + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); +} - SmallVector<Operation *> moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); - // Move local_stores early if dependence distance greater than - // one iteration. - // Best perf on GEMM when these precede global loads. - m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - - for (auto op : llvm::reverse(moveOps)) { - // Gather use-def chain in block. - Block *block = op->getBlock(); - bool leadsToLoad = false; - SetVector<Operation *> backwardSet; - - BackwardSliceOptions options; - options.omitBlockArguments = true; - options.inclusive = false; - options.filter = [&](Operation *defOp) -> bool { - Block *defBlock = defOp->getBlock(); - if (!block->findAncestorOpInBlock(*defOp)) - return false; - // Check for a `load` dependent path. - leadsToLoad |= isa<triton::LoadOp>(defOp); - // Only move ops residing in the same block. - return defBlock == block; - }; - mlir::getBackwardSlice(op, &backwardSet, options); - backwardSet.insert(op); - - // Don't move a local_store if its source is a load from - // the same iteration. +// Schedule global load and local store ops for better GEMM performance. +static void scheduleGlobalLoadLocalStore(ModuleOp m) { + SmallVector<Operation *> moveOps; + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than one iteration. + // Best perf on GEMM when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector<Operation *> backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa<triton::LoadOp>(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. 
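+ // (Hoisting such a store would also hoist the load it depends on via the + // backward slice, defeating the goal of issuing loads early.)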
+ if (isa<ttg::LocalStoreOp>(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector<Operation *> dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); } } +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { + void runOnOperation() override { + ModuleOp m = getOperation(); + + hoistLocalLoad(m); + + sinkDotConversion(m); + moveDownConversion(m); + + moveUpTranspose(m); + + if (isPureMatmulProblem(m)) + scheduleGlobalLoadLocalStore(m); + } }; +} // namespace std::unique_ptr<Pass> mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique<TritonAMDGPUReorderInstructionsPass>(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 75f9354104b1..d1cef15a354e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -93,20 +93,12 @@ static std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op, int computeCapability) { if (computeCapability < 80) return std::nullopt; - if (op.getNumOperands() != 1 || op.getNumResults() != 1) - return std::nullopt; - Block *block = &(*op.getCombineOp().begin()); - Operation *yield = block->getTerminator(); - Operation *reduceOp = yield->getOperand(0).getDefiningOp(); - if (!reduceOp || reduceOp->getNumOperands() != 2 || - reduceOp->getNumResults() != 1) + Operation *reduceOp = op.getSingleCombiner(); + if (!reduceOp) return std::nullopt; auto intType = dyn_cast<IntegerType>(reduceOp->getResultTypes()[0]); if (!intType || intType.getWidth() > 32) return std::nullopt; - if (reduceOp->getOperand(0) != block->getArgument(0) || - reduceOp->getOperand(1) != block->getArgument(1)) - return std::nullopt; if (isa<arith::AddIOp>(reduceOp)) return NVVM::ReduxKind::ADD; if (isa(reduceOp))