From 4b50c4891bbfa0fdf84a4d3bf4147e4fea7f2f35 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 17 Oct 2024 08:38:54 +0000 Subject: [PATCH 01/11] Revert "[AMD] revert optimizations (#4919)" This reverts commit 93de4266fd2be5ee053d87bdd00596b08b1debb5. --- bin/RegisterTritonDialects.h | 2 + .../PatternTritonGPUOpToLLVM.h | 26 +- .../Conversion/TritonGPUToLLVM/Utility.h | 10 +- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 36 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 8 +- test/TritonGPU/amd/amd-instruction-sched.mlir | 148 +++++++ .../TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td | 27 ++ .../TritonAMDGPU/IR/TritonAMDGPUDialect.td | 3 + .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 22 +- .../amd/include/TritonAMDGPUToLLVM/Passes.h | 4 +- .../amd/include/TritonAMDGPUToLLVM/Passes.td | 14 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 6 + .../SharedToDotOperandMFMA.cpp | 11 +- .../SharedToDotOperandWMMA.cpp | 11 +- .../TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp | 10 +- .../TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp | 5 + .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 6 +- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 388 ++++++++++++++++-- .../TritonAMDGPUToLLVM/SchedInstructions.h | 22 + .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 15 +- third_party/amd/python/triton_amd.cc | 4 +- 21 files changed, 715 insertions(+), 63 deletions(-) create mode 100644 test/TritonGPU/amd/amd-instruction-sched.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d5eb81eb9f4a..f094ce963a5a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes registry.insert + localStoreOpConversion = nullptr; +}; + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 56a82d7cc0fb..a3d8fe9e64ba 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1366,11 +1366,11 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target); +void storeDistributedToShared( + MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index e2ed0228de8d..38fa1bd62343 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. -void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo); + loc, rewriter, targetInfo, llvmOpCount); } struct LocalAllocOpConversion @@ -200,12 +199,15 @@ struct LocalStoreOpConversion public: using ConvertOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -215,24 +217,36 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6cb..67954e5daede 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target) { + const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { @@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 000000000000..bca502f980cb --- /dev/null +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,148 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -check-prefix=INSTR_INSERTION +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 -triton-amdgpu-lower-insert-instruction-sched-hints=variant="iglp0" | FileCheck %s -check-prefix=LOWER_IGLP0 + +#shared0_ex0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#mma0_ex0 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> + +#blocked0_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1_ex1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared0_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#mma0_ex1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> +#dot0_ex1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex1, kWidth = 8}> +#dot1_ex1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex1, kWidth = 8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // LOWER_IGLP0-LABEL: test_instruction_hints_lowering + tt.func @test_instruction_hints_lowering( + %arg0: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>>, + %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>>, + %arg2: tensor<32x32xf16, #mma0_ex0>) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 1 : i32 + + scf.for %arg11 = %c0_i32 to %c64_i32 step %c1_i32 iter_args() -> () : i32 { + // LOWER_IGLP0: llvm.add + // LOWER_IGLP0-NEXT: %[[OPT_LEVEL:.*]] = llvm.mlir.constant(0 : i32) : i32 + // LOWER_IGLP0-NEXT: llvm.call_intrinsic "llvm.amdgcn.iglp.opt"(%[[OPT_LEVEL]]) : (i32) -> () + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>> -> tensor<32x32xf16, #mma0_ex0> + scf.yield + } + tt.return + } + + // INSTR_INSERTION-LABEL: @test_llvm_instruction_count + tt.func public @test_llvm_instruction_count( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + + %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked0_ex1> + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked1_ex1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + %21 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %22 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + %23 = arith.addi %21, %19 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %24 = arith.addi %22, %20 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + + %26 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %28 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %29 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %30 = arith.addi %28, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %31 = arith.addi %29, %27 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %32 = tt.expand_dims %23 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> -> tensor<256x1xi32, #blocked0_ex1> + %33 = tt.expand_dims %24 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> -> tensor<256x1xi32, #blocked2_ex1> + %34 = tt.splat %c64_i32 : i32 -> tensor<256x1xi32, #blocked0_ex1> + %35 = arith.muli %32, %34 : tensor<256x1xi32, #blocked0_ex1> + %36 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked0_ex1> + %37 = tt.addptr %36, %35 : tensor<256x1x!tt.ptr, #blocked0_ex1>, tensor<256x1xi32, #blocked0_ex1> + %38 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> + %39 = tt.expand_dims %38 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> -> tensor<1x64xi32, #blocked0_ex1> + %40 = tt.broadcast %37 : tensor<256x1x!tt.ptr, #blocked0_ex1> -> tensor<256x64x!tt.ptr, #blocked0_ex1> + %41 = tt.broadcast %39 : tensor<1x64xi32, #blocked0_ex1> -> tensor<256x64xi32, #blocked0_ex1> + %42 = tt.addptr %40, %41 : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> + + %43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> -> tensor<64x1xi32, #blocked1_ex1> + %45 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1_ex1> + %46 = tt.addptr %45, %44 : tensor<64x1x!tt.ptr, #blocked1_ex1>, tensor<64x1xi32, #blocked1_ex1> + %47 = tt.expand_dims %30 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> -> tensor<1x128xi32, #blocked1_ex1> + %48 = tt.expand_dims %31 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> -> tensor<1x128xi32, #blocked2_ex1> + %49 = tt.splat %c64_i32 : i32 -> tensor<1x128xi32, #blocked1_ex1> + %50 = arith.muli %47, %49 : tensor<1x128xi32, #blocked1_ex1> + %51 = tt.broadcast %46 : tensor<64x1x!tt.ptr, #blocked1_ex1> -> tensor<64x128x!tt.ptr, #blocked1_ex1> + %52 = tt.broadcast %50 : tensor<1x128xi32, #blocked1_ex1> -> tensor<64x128xi32, #blocked1_ex1> + %53 = tt.addptr %51, %52 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> + + %56 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %57 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma0_ex1> + + %cc0_i1 = arith.constant 1 : i1 + %59 = tt.splat %cc0_i1 : i1 -> tensor<256x64xi1, #blocked0_ex1> + %60 = tt.load %42, %59 : tensor<256x64x!tt.ptr, #blocked0_ex1> + %61 = tt.splat %cc0_i1 : i1 -> tensor<64x128xi1, #blocked1_ex1> + %62 = tt.load %53, %61 : tensor<64x128x!tt.ptr, #blocked1_ex1> + + %63 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %60, %63 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %64 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %62, %64 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + %66:5 = scf.for %arg11 = %c0_i32 to %c63_i32 step %c1_i32 iter_args( + %arg12 = %cst_1, + %arg13 = %42, + %arg14 = %53, + %arg16 = %63, + %arg17 = %64) -> ( + tensor<256x128xf32, #mma0_ex1>, + tensor<256x64x!tt.ptr, #blocked0_ex1>, + tensor<64x128x!tt.ptr, #blocked1_ex1>, + !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, + !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>) : i32 { + + %82 = triton_gpu.local_load %arg16 : !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dot0_ex1> + %83 = triton_gpu.local_load %arg17 : !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> tensor<64x128xf16, #dot1_ex1> + + // INSTR_INSERTION: amdgpu.instruction_sched_hint + // INSTR_INSERTION-SAME: numDsReadsA = #amdgpu.InstCounter<16, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsWritesA = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<8xf16>> + // INSTR_INSERTION-SAME: numGlobalLoadsA = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<8xf16>> + // INSTR_INSERTION-SAME: numMMAs = #amdgpu.InstCounter<64, tensor<32x32x8xf16>> + + %84 = tt.dot %82, %83, %arg12 : tensor<256x64xf16, #dot0_ex1> * tensor<64x128xf16, #dot1_ex1> -> tensor<256x128xf32, #mma0_ex1> + %85 = tt.addptr %arg13, %cst : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> + %86 = tt.addptr %arg14, %cst_0 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> + %87 = tt.load %85 : tensor<256x64x!tt.ptr, #blocked0_ex1> + %88 = tt.load %86 : tensor<64x128x!tt.ptr, #blocked1_ex1> + %89 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %87, %89 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %90 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %88, %90 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + scf.yield %84, %85, %86, %89, %90 : + tensor<256x128xf32, #mma0_ex1>, + tensor<256x64x!tt.ptr, #blocked0_ex1>, + tensor<64x128x!tt.ptr, #blocked1_ex1>, + !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, + !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + } + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 31a43acd2f89..c0aa08421bdd 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. + }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index d5956cf7a33c..c0c18b07e907 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect { }]; let dependentDialects = []; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 538e31378fe8..494e45819836 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -57,7 +57,27 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { interleave for better instruction level parallelism. }]; - let assemblyFormat = [{attr-dict}]; + let arguments = (ins + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins), [{ + auto ctx = $_state.getContext(); + auto type = IntegerType::get(ctx, 32); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type); + build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; } // diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index bd726bd845d2..969b357a74f6 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -36,9 +36,9 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(bool ftz); std::unique_ptr> -createInsertInstructionSchedHintsPass(); +createTritonAMDGPUInsertInstructionSchedHintsPass(); std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 9f4665aef217..b9b06e47d217 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -59,18 +59,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul ]; } -def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; - let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; } -def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(\"\")"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"variant", "variant", "std::string", /*default*/"\"default\"", diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index a82a77e9f57e..1e429fdc39a9 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -24,6 +24,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" + +#include "llvm/ADT/TypeSwitch.h" // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" @@ -45,5 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index b832d985bbe7..9043090802bf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -336,6 +337,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int elemsPerLoad = numOfElems / loadsPerThread; assert(numOfElems % loadsPerThread == 0); + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -346,7 +348,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; @@ -363,6 +364,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = mfmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index b60c86e1a3a5..1ca9e49745d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -212,6 +213,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -221,7 +223,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); Value valVec = undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; loadOffset = add(loadOffset, batchOffset); @@ -237,6 +238,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = wmmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 204d54894d3b..1eed112c30c0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "Utility.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" using namespace mlir; @@ -261,6 +261,14 @@ struct DotOpMFMAConversionHelper { Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + Type elemtTy = elemTyA; + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(), + maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(), + elemtTy); + rewriter.replaceOp(op, res); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 5a003f768833..0042cf89e93b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -22,6 +22,7 @@ */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -325,6 +326,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, Type structTy = LLVM::LLVMStructType::getLiteral( wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5265f631ad9e..437b64b438d1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,6 +1,7 @@ #include "BufferOpsEmitter.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -276,6 +277,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; @@ -286,7 +288,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, assert(wordNElems * nWords * numVecs == numElems); Value pred = mask ? maskElems[vecStart] : int_val(1, 1); - auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); @@ -391,6 +392,9 @@ struct BufferLoadOpConversion Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 9bed87961966..3c30ae7e544c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,5 +1,4 @@ #include "TritonAMDGPUToLLVM/Passes.h" - #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" @@ -7,13 +6,77 @@ #include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { -#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS -#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton using namespace mlir; +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = + cast(op->getAttr(amdgpu::OpIdxAttr::getMnemonic())); + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumGlobalLoadsAAttr(counterAttr); + else + schedHint.setNumGlobalLoadsBAttr(counterAttr); + }); +} + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = + op->getAttrOfType(amdgpu::OpIdxAttr::getMnemonic()); + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + }); +} +} // namespace mlir::triton + namespace { // The bitmask that encodes kinds of the instructions from AMD ISA. @@ -52,7 +115,7 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, } // Insert intrinsic that controls the types of instructions that may be -// allowed to cross the intrinsic during instruction scheduling +// allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, int64_t maskValue) { MLIRContext *ctx = rewriter.getContext(); @@ -78,7 +141,7 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { } struct InstructionSchedHintsRewriter - : public OpRewritePattern { + : public OpRewritePattern { InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) : OpRewritePattern(ctx) { @@ -89,13 +152,119 @@ struct InstructionSchedHintsRewriter .Case("default", SchedulingType::NONE) .Case("iglp0", SchedulingType::IGLP0) .Case("iglp1", SchedulingType::IGLP1) + .Case("ck_v3", SchedulingType::CK_V3) .Default(SchedulingType::UNKNOWN); } - enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + enum class SchedulingType : uint32_t { + NONE = 0, + IGLP0, + IGLP1, + CK_V3, + UNKNOWN + }; + + // This is the implementation of the CK's V3 pipelining (see + // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBN to registers) data prefetching. + // see: + // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h + void createCKV3Schedule(PatternRewriter &rewriter, Location loc, + amdgpu::InstructionSchedHint schedHint) const { + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); + + auto mfmaType = cast(schedHint.getNumMMAs().getType()); + const uint32_t nPerXDL = mfmaType.getShape()[1]; + const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32; + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4; + const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ? 8 : 4; + + const auto dsReadAMfmaRate = + (mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle); + const auto dsReadBMfmaRate = + (mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle); + + const auto numDsreadAMfma = + (numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate; + const auto numDsreadBMfma = + (numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate; + + // stage 1 + const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma); + const auto num_mfma_per_issue = + numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + num_mfma_per_issue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + num_mfma_per_issue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMfma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + dsReadAMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, InstructionKindMask::DS_READ, + numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMfma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + dsReadBMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, InstructionKindMask::DS_READ, + numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + } LogicalResult - matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, + matchAndRewrite(amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { if (this->schedulingType == SchedulingType::UNKNOWN) { @@ -110,7 +279,8 @@ struct InstructionSchedHintsRewriter // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. const bool limitSchedulingRange = - !(schedulingType == SchedulingType::IGLP0 || + !(schedulingType == SchedulingType::NONE || + schedulingType == SchedulingType::IGLP0 || schedulingType == SchedulingType::IGLP1); Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); @@ -128,6 +298,10 @@ struct InstructionSchedHintsRewriter createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); break; } + case SchedulingType::CK_V3: { + createCKV3Schedule(rewriter, loc, instructionSchedHint); + break; + } case SchedulingType::NONE: [[fallthrough]]; default: { @@ -146,11 +320,11 @@ struct InstructionSchedHintsRewriter SchedulingType schedulingType; }; -struct LowerInstructionSchedHints - : public triton::impl::LowerInstructionSchedHintsBase< - LowerInstructionSchedHints> { +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { - explicit LowerInstructionSchedHints(std::string variant) { + explicit TritonAMDGPULowerInstructionSchedHints(std::string variant) { this->variant = variant; } @@ -160,7 +334,7 @@ struct LowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx, this->variant); @@ -172,32 +346,200 @@ struct LowerInstructionSchedHints } }; -struct InsertInstructionSchedHints - : public triton::impl::InsertInstructionSchedHintsBase< - InsertInstructionSchedHints> { +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); - mod->walk([ctx](triton::DotOp dot) { - if (dyn_cast(dot->getParentOp())) { + mod.walk([this, ctx](scf::ForOp forOp) { + triton::DotOp dot = nullptr; + size_t dotCounter = 0; + forOp->walk([&dot, &dotCounter](triton::DotOp op) { + dot = op; + ++dotCounter; + }); + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. + if (dotCounter == 1) { mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + rewriter.create(dot->getLoc()); + annotateDotUsageOnLoadStore(forOp); } }); } + + template bool isOf(Operation *op) const { + return llvm::isa(op); + } + + template + llvm::SmallVector getUsersOfTypes(Value value) const { + llvm::SmallVector concreteUsers; + for (auto user : value.getUsers()) { + std::vector values = {(isOf(user), ...)}; + if (llvm::any_of(values, [](bool value) { return value; })) + concreteUsers.push_back(user); + } + return concreteUsers; + } + + template + llvm::SmallVector getUsersOfType(Value value) const { + auto users = getUsersOfTypes(value); + llvm::SmallVector concreteUsers; + for (auto user : getUsersOfTypes(value)) { + concreteUsers.push_back(cast(user)); + } + return concreteUsers; + } + + // Go through a single use chain of `convert_layout` and/or `fp_to_fp` Ops to + // get the final value after all conversions + Value rewindUnaryOps(Value value) const { + auto unaryOps = + getUsersOfTypes(value); + while (!unaryOps.empty()) { + assert(unaryOps.size() == 1); + value = unaryOps[0]->getResult(0); + unaryOps = + getUsersOfTypes( + value); + } + return value; + } + + // Given a `scf::ForOp`, the method finds and annotates all Ops which produce + // input values for the `tt.dot` operation. The algorithm handles software + // pipelining. Therefore, we start by tracking `tt.load` ops and unwind the + // data flow by looking up to the yielded values and iteration arguments of a + // given `scf::ForOp` till we find `ttg.local_store` Op. Once a + // `ttg.local_store` Op is found, we need a single yielded-arguments lookup to + // find the corresponding `ttg.local_load` Op from which we have a direct data + // flow path to the target `tt.dot` op. At this point, we can annotate all + // found Ops (i.e., `tt.load`, `ttg.local_store`) with the input argument + // index of the data to `tt.dot` Op. Here is an example of the resulting + // annotated TTGIR: + // + // %13:8 = scf.for %arg11 = %c0_i32 to %0 step %c1_i32 iter_args( + // %arg0 = %cst_1, %arg1 = %in_0, %arg2 = %in_1, %arg3 = %c0_i32, + // %arg4 = %in_2, %arg5 = %in_3, %arg6 = %in_4, %arg7 = %in_5) + // -> (...) : i32 { + // %1 = triton_gpu.local_load %arg4 : {OpIdx = 0} + // %2 = triton_gpu.local_load %arg5 : {OpIdx = 1} + // %3 = tt.dot %1, %2, %arg0 + // %4 = tt.addptr %arg1, %cst + // %5 = tt.addptr %arg2, %cst_0 + // %6 = tt.load %4 : {OpIdx = 0} + // %7 = tt.load %5 : {OpIdx = 1} + // %8 = arith.addi %arg3, %c1_i32 + // %9 = arith.cmpi slt, %8, %c2_i32 + // %10 = arith.select %9, %8, %c0_i32 + // %11 = triton_gpu.memdesc_subview %56[%10, %c0_i32, %c0_i32] + // triton_gpu.local_store %arg6, %11 : {OpIdx = 0} + // %12 = triton_gpu.memdesc_subview %57[%10, %c0_i32, %c0_i32] + // triton_gpu.local_store %arg7, %12 : {OpIdx = 1} + // scf.yield %3, %4, %5, %10, %11, %12, %6, %7 : (...) + // } + // + // Note, this is required for counting issued `llvm` instructions during + // lowering from TTGIR to LLVM dialects to perform advanced instruction + // scheduling. + void annotateDotUsageOnLoadStore(scf::ForOp forOp) const { + llvm::SmallVector loadOps; + forOp.walk( + [&loadOps](triton::LoadOp loadOp) { loadOps.push_back(loadOp); }); + + ValueRange yieldedValues = forOp.getYieldedValues(); + auto initArgs = forOp.getRegionIterArgs(); + + MLIRContext *ctx = forOp->getContext(); + mlir::OpBuilder rewriter(ctx); + + for (auto loadOp : loadOps) { + Value loadResult = loadOp.getResult(); + + // Unwind till the first carried loop iteration regarding `tt.load`. + Value loopCarriedLoadValue = loadResult; + bool foundFirstCarriedLoopIteration = false; + while (!foundFirstCarriedLoopIteration) { + auto it = llvm::find(yieldedValues, loopCarriedLoadValue); + if (it != yieldedValues.end()) { + size_t idx = std::distance(yieldedValues.begin(), it); + loopCarriedLoadValue = initArgs[idx]; + } else { + foundFirstCarriedLoopIteration = true; + } + } + + loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + assert(loopCarriedLoadValue.hasOneUse()); + + // Handle pipelining - i.e., `local_store`, `memdesc_subview`, + // `local_load` ops. + triton::gpu::LocalLoadOp localLoadOp = nullptr; + auto loadOpUser = *(loopCarriedLoadValue.user_begin()); + auto localStoreOp = llvm::dyn_cast(loadOpUser); + if (localStoreOp) { + auto subviewOp = localStoreOp.getDst() + .getDefiningOp(); + Value subviewResult = subviewOp.getResult(); + auto it = llvm::find(yieldedValues, subviewResult); + if (it != yieldedValues.end()) { + size_t idx = std::distance(yieldedValues.begin(), it); + Value loopCarriedSubviewValue = initArgs[idx]; + + auto subviewLoadOps = + getUsersOfType(loopCarriedSubviewValue); + assert(subviewLoadOps.size() == 1); + localLoadOp = *subviewLoadOps.begin(); + + loopCarriedLoadValue = localLoadOp.getResult(); + } else { + auto localLoadOps = + getUsersOfType(subviewResult); + assert(localLoadOps.size() == 1); + localLoadOp = *localLoadOps.begin(); + auto it = llvm::find(yieldedValues, localLoadOp.getResult()); + assert(it != yieldedValues.end()); + size_t idx = std::distance(yieldedValues.begin(), it); + loopCarriedLoadValue = initArgs[idx]; + } + loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + } + + // Find the corresponding `DotOp`. + auto dots = getUsersOfType(loopCarriedLoadValue); + assert(dots.size() == 1); + + // Find which `DotOp` argument the current `loadOp` belongs to. + auto dotOperands = dots.begin()->getOperands(); + auto it = llvm::find(dotOperands, loopCarriedLoadValue); + assert(it != dotOperands.end()); + size_t opIdx = std::distance(dotOperands.begin(), it); + + // Set `OpIdx` attributes. + auto opIdxAttr = amdgpu::OpIdxAttr::get(ctx, opIdx); + + loadOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + if (localStoreOp) + localStoreOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } }; } // namespace namespace mlir::triton { std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant) { + return std::make_unique(variant); } std::unique_ptr> -createInsertInstructionSchedHintsPass() { - return std::make_unique(); +createTritonAMDGPUInsertInstructionSchedHintsPass() { + return std::make_unique(); } } // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 000000000000..6b81dd0ab2db --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H + +#include "mlir/IR/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during to LLVM conversion/lowering to facilitate instruction scheduling +// controls. +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); +void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +} // namespace mlir::triton + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index d227bb6c6a4b..f99cd50b0d27 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -20,6 +21,7 @@ #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -72,8 +74,9 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -193,8 +196,12 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, commonBenefit); + + mlir::triton::BackendCallbacks callbacks; + callbacks.localStoreOpConversion = storeOpConversionCallback; + + mlir::triton::populateMemoryOpToLLVMPattern( + typeConverter, targetInfo, patterns, commonBenefit, callbacks); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index f97676aafe36..6f1df5a43fae 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -45,11 +45,11 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); }); m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { - pm.addPass(createInsertInstructionSchedHintsPass()); + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createLowerInstructionSchedHintsPass(variant)); + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) { From d53c4999e7871a62f70f9ab314544a04cedf506d Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 14 Oct 2024 15:24:40 +0000 Subject: [PATCH 02/11] [AMD] use rocdl instr.sched.barriers from upstream MLIR/ROCDL --- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 169 +++++++++--------- 1 file changed, 82 insertions(+), 87 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 3c30ae7e544c..b78a4dd49d36 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,5 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -18,9 +20,10 @@ void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, unsigned k, Type elementType) { auto *ctx = op->getContext(); auto mmaType = RankedTensorType::get({m, n, k}, elementType); - auto counterAttr = amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); - op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { schedHint.setNumMMAsAttr(counterAttr); }); } @@ -28,11 +31,12 @@ void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, Type type) { MLIRContext *ctx = op->getContext(); - auto counterAttr = amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); - op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { - auto opIdxAttr = - cast(op->getAttr(amdgpu::OpIdxAttr::getMnemonic())); + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = cast( + op->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())); assert(opIdxAttr.getValue() < 2); if (opIdxAttr.getValue() == 0) schedHint.setNumGlobalLoadsAAttr(counterAttr); @@ -44,9 +48,10 @@ void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, Type type) { auto *ctx = op->getContext(); - auto counterAttr = amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); - op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { Value dst = op.getResult(); auto dstTensorTy = cast(dst.getType()); auto dotOperandLayout = @@ -63,11 +68,12 @@ void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t localStoreOpCount, Type type) { MLIRContext *ctx = op->getContext(); - auto counterAttr = amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); - op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { - auto opIdxAttr = - op->getAttrOfType(amdgpu::OpIdxAttr::getMnemonic()); + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic()); assert(opIdxAttr.getValue() < 2); if (opIdxAttr.getValue() == 0) schedHint.setNumDsWritesAAttr(counterAttr); @@ -79,69 +85,39 @@ void storeOpConversionCallback(triton::gpu::LocalStoreOp op, namespace { -// The bitmask that encodes kinds of the instructions from AMD ISA. -// The bitmask is used for providing instruction scheduling hints. -enum InstructionKindMask { - NONE = 0x0000000, - ALL_ALU = 0x00000001, - VALU = 0x00000002, - SALU = 0x00000004, - MFMA = 0x00000008, - ALL_VMEM = 0x00000010, - VMEM_READ = 0x00000020, - VMEM_WRITE = 0x00000040, - ALL_DS = 0x00000080, - DS_READ = 0x00000100, - DS_WRITE = 0x00000200 -}; - // Create an intrinsic to control how different instruction kinds should // interleave for better ILP. void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, - InstructionKindMask maskValue, int sizeValue, - int groupIdValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.group.barrier"; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - Value size = - LLVM::createConstantI32(loc, rewriter, static_cast(sizeValue)); - Value groupId = LLVM::createConstantI32(loc, rewriter, - static_cast(groupIdValue)); - - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{}, - ValueRange{mask, size, groupId}); + mlir::amdgpu::sched_barrier_opt_enum maskValue, + int sizeValue, int groupIdValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + IntegerAttr size = + rewriter.getI32IntegerAttr(static_cast(sizeValue)); + IntegerAttr groupId = + rewriter.getI32IntegerAttr(static_cast(groupIdValue)); + rewriter.create(loc, mask, size, groupId); } // Insert intrinsic that controls the types of instructions that may be // allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, - int64_t maskValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.barrier"; - LLVM::FastmathFlagsAttr defaultFlags{}; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{mask}); + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); } // Insert an experimental intrinsic for instruction group level parallelism. // The intrinsic takes a value that specifies the strategy. Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.iglp.opt"; - LLVM::FastmathFlagsAttr defaultFlags{}; - Value iglpValue = - LLVM::createConstantI32(loc, rewriter, static_cast(value)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{iglpValue}); + IntegerAttr iglpValue = + rewriter.getI32IntegerAttr(static_cast(value)); + return rewriter.create(loc, iglpValue); } struct InstructionSchedHintsRewriter - : public OpRewritePattern { + : public OpRewritePattern { InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) : OpRewritePattern(ctx) { @@ -170,8 +146,9 @@ struct InstructionSchedHintsRewriter // local (LDS to registers) and global (HBN to registers) data prefetching. // see: // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h - void createCKV3Schedule(PatternRewriter &rewriter, Location loc, - amdgpu::InstructionSchedHint schedHint) const { + void + createCKV3Schedule(PatternRewriter &rewriter, Location loc, + triton::amdgpu::InstructionSchedHint schedHint) const { const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); @@ -215,56 +192,68 @@ struct InstructionSchedHintsRewriter for (size_t i = 0; i < numBufferLoadInstA; ++i) { for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, - 0); - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); } - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, - 0); - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, num_mfma_per_issue - numDswritePerIssueA, 0); } for (size_t i = 0; i < numBufferLoadInstB; ++i) { for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, - 0); - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); } - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, - 0); - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, num_mfma_per_issue - numDswritePerIssueB, 0); } // stage 2 for (size_t i = 0; i < numDsreadAMfma; ++i) { if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) { - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, dsReadAMfmaRate, 0); } else { createSchedGroupBarrier( - rewriter, loc, InstructionKindMask::DS_READ, + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0); } - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); } for (size_t i = 0; i < numDsreadBMfma; ++i) { if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) { - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, dsReadBMfmaRate, 0); } else { createSchedGroupBarrier( - rewriter, loc, InstructionKindMask::DS_READ, + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0); } - createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); } } LogicalResult - matchAndRewrite(amdgpu::InstructionSchedHint instructionSchedHint, + matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { if (this->schedulingType == SchedulingType::UNKNOWN) { @@ -286,7 +275,8 @@ struct InstructionSchedHintsRewriter Block *block = instructionSchedHint->getBlock(); if (limitSchedulingRange) { rewriter.setInsertionPointToStart(block); - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); } rewriter.setInsertionPoint(block, std::prev(block->end())); @@ -310,7 +300,8 @@ struct InstructionSchedHintsRewriter } if (limitSchedulingRange) - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); rewriter.eraseOp(instructionSchedHint); return mlir::success(); @@ -334,7 +325,10 @@ struct TritonAMDGPULowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx, this->variant); @@ -366,7 +360,7 @@ struct TritonAMDGPUInsertInstructionSchedHints if (dotCounter == 1) { mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + rewriter.create(dot->getLoc()); annotateDotUsageOnLoadStore(forOp); } }); @@ -522,11 +516,12 @@ struct TritonAMDGPUInsertInstructionSchedHints size_t opIdx = std::distance(dotOperands.begin(), it); // Set `OpIdx` attributes. - auto opIdxAttr = amdgpu::OpIdxAttr::get(ctx, opIdx); + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); - loadOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); if (localStoreOp) - localStoreOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + localStoreOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); } } }; From ce029688fc8bebf759994f10a45b6a06457427d7 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 17 Oct 2024 09:50:34 +0000 Subject: [PATCH 03/11] [AMD] fixed a bug resulted in reverting PR#4919 Replaced temlate-based impl. of `rewindUnaryOps` in `SchedInstructions.cpp` using regular for-loops. The new impl. is more robust and can handle other unary ops automatically. --- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 53 +++++++++---------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index b78a4dd49d36..110ab9246f39 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -366,42 +366,33 @@ struct TritonAMDGPUInsertInstructionSchedHints }); } - template bool isOf(Operation *op) const { - return llvm::isa(op); - } - - template - llvm::SmallVector getUsersOfTypes(Value value) const { - llvm::SmallVector concreteUsers; - for (auto user : value.getUsers()) { - std::vector values = {(isOf(user), ...)}; - if (llvm::any_of(values, [](bool value) { return value; })) - concreteUsers.push_back(user); - } - return concreteUsers; - } - template llvm::SmallVector getUsersOfType(Value value) const { - auto users = getUsersOfTypes(value); llvm::SmallVector concreteUsers; - for (auto user : getUsersOfTypes(value)) { - concreteUsers.push_back(cast(user)); + for (auto user : value.getUsers()) { + if (auto concreteUser = llvm::dyn_cast(user)) + concreteUsers.push_back(concreteUser); } return concreteUsers; } - // Go through a single use chain of `convert_layout` and/or `fp_to_fp` Ops to - // get the final value after all conversions - Value rewindUnaryOps(Value value) const { - auto unaryOps = - getUsersOfTypes(value); + // Go through a single use chain to get the final value after all unary + // ops - e.g., `convert_layout`, `fp_to_fp`, etc. + FailureOr rewindUnaryOps(Value value) const { + auto findUnaryOps = [](Value value) { + llvm::SmallVector unaryOps; + for (Operation *op : value.getUsers()) { + if (op->getNumOperands() == 1) + unaryOps.push_back(op); + } + return unaryOps; + }; + auto unaryOps = findUnaryOps(value); while (!unaryOps.empty()) { - assert(unaryOps.size() == 1); + if (unaryOps.size() != 1) + return failure(); value = unaryOps[0]->getResult(0); - unaryOps = - getUsersOfTypes( - value); + unaryOps = findUnaryOps(value); } return value; } @@ -469,7 +460,9 @@ struct TritonAMDGPUInsertInstructionSchedHints } } - loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + auto maybeRewoundResults = rewindUnaryOps(loopCarriedLoadValue); + assert(succeeded(maybeRewoundResults)); + loopCarriedLoadValue = *maybeRewoundResults; assert(loopCarriedLoadValue.hasOneUse()); // Handle pipelining - i.e., `local_store`, `memdesc_subview`, @@ -502,7 +495,9 @@ struct TritonAMDGPUInsertInstructionSchedHints size_t idx = std::distance(yieldedValues.begin(), it); loopCarriedLoadValue = initArgs[idx]; } - loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + maybeRewoundResults = rewindUnaryOps(loopCarriedLoadValue); + assert(succeeded(maybeRewoundResults)); + loopCarriedLoadValue = *maybeRewoundResults; } // Find the corresponding `DotOp`. From 2aecafa74b288c11de873dea9defc3bb17f029c2 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 21 Oct 2024 15:16:51 +0000 Subject: [PATCH 04/11] [AMD] Moved `annotateDotUsageOnLoadStore` to stream pipeliner --- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 155 ------------------ .../StreamPipelineV2.cpp | 55 +++++++ 2 files changed, 55 insertions(+), 155 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 110ab9246f39..b762bcb0e0b0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -361,164 +361,9 @@ struct TritonAMDGPUInsertInstructionSchedHints mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dot); rewriter.create(dot->getLoc()); - annotateDotUsageOnLoadStore(forOp); } }); } - - template - llvm::SmallVector getUsersOfType(Value value) const { - llvm::SmallVector concreteUsers; - for (auto user : value.getUsers()) { - if (auto concreteUser = llvm::dyn_cast(user)) - concreteUsers.push_back(concreteUser); - } - return concreteUsers; - } - - // Go through a single use chain to get the final value after all unary - // ops - e.g., `convert_layout`, `fp_to_fp`, etc. - FailureOr rewindUnaryOps(Value value) const { - auto findUnaryOps = [](Value value) { - llvm::SmallVector unaryOps; - for (Operation *op : value.getUsers()) { - if (op->getNumOperands() == 1) - unaryOps.push_back(op); - } - return unaryOps; - }; - auto unaryOps = findUnaryOps(value); - while (!unaryOps.empty()) { - if (unaryOps.size() != 1) - return failure(); - value = unaryOps[0]->getResult(0); - unaryOps = findUnaryOps(value); - } - return value; - } - - // Given a `scf::ForOp`, the method finds and annotates all Ops which produce - // input values for the `tt.dot` operation. The algorithm handles software - // pipelining. Therefore, we start by tracking `tt.load` ops and unwind the - // data flow by looking up to the yielded values and iteration arguments of a - // given `scf::ForOp` till we find `ttg.local_store` Op. Once a - // `ttg.local_store` Op is found, we need a single yielded-arguments lookup to - // find the corresponding `ttg.local_load` Op from which we have a direct data - // flow path to the target `tt.dot` op. At this point, we can annotate all - // found Ops (i.e., `tt.load`, `ttg.local_store`) with the input argument - // index of the data to `tt.dot` Op. Here is an example of the resulting - // annotated TTGIR: - // - // %13:8 = scf.for %arg11 = %c0_i32 to %0 step %c1_i32 iter_args( - // %arg0 = %cst_1, %arg1 = %in_0, %arg2 = %in_1, %arg3 = %c0_i32, - // %arg4 = %in_2, %arg5 = %in_3, %arg6 = %in_4, %arg7 = %in_5) - // -> (...) : i32 { - // %1 = triton_gpu.local_load %arg4 : {OpIdx = 0} - // %2 = triton_gpu.local_load %arg5 : {OpIdx = 1} - // %3 = tt.dot %1, %2, %arg0 - // %4 = tt.addptr %arg1, %cst - // %5 = tt.addptr %arg2, %cst_0 - // %6 = tt.load %4 : {OpIdx = 0} - // %7 = tt.load %5 : {OpIdx = 1} - // %8 = arith.addi %arg3, %c1_i32 - // %9 = arith.cmpi slt, %8, %c2_i32 - // %10 = arith.select %9, %8, %c0_i32 - // %11 = triton_gpu.memdesc_subview %56[%10, %c0_i32, %c0_i32] - // triton_gpu.local_store %arg6, %11 : {OpIdx = 0} - // %12 = triton_gpu.memdesc_subview %57[%10, %c0_i32, %c0_i32] - // triton_gpu.local_store %arg7, %12 : {OpIdx = 1} - // scf.yield %3, %4, %5, %10, %11, %12, %6, %7 : (...) - // } - // - // Note, this is required for counting issued `llvm` instructions during - // lowering from TTGIR to LLVM dialects to perform advanced instruction - // scheduling. - void annotateDotUsageOnLoadStore(scf::ForOp forOp) const { - llvm::SmallVector loadOps; - forOp.walk( - [&loadOps](triton::LoadOp loadOp) { loadOps.push_back(loadOp); }); - - ValueRange yieldedValues = forOp.getYieldedValues(); - auto initArgs = forOp.getRegionIterArgs(); - - MLIRContext *ctx = forOp->getContext(); - mlir::OpBuilder rewriter(ctx); - - for (auto loadOp : loadOps) { - Value loadResult = loadOp.getResult(); - - // Unwind till the first carried loop iteration regarding `tt.load`. - Value loopCarriedLoadValue = loadResult; - bool foundFirstCarriedLoopIteration = false; - while (!foundFirstCarriedLoopIteration) { - auto it = llvm::find(yieldedValues, loopCarriedLoadValue); - if (it != yieldedValues.end()) { - size_t idx = std::distance(yieldedValues.begin(), it); - loopCarriedLoadValue = initArgs[idx]; - } else { - foundFirstCarriedLoopIteration = true; - } - } - - auto maybeRewoundResults = rewindUnaryOps(loopCarriedLoadValue); - assert(succeeded(maybeRewoundResults)); - loopCarriedLoadValue = *maybeRewoundResults; - assert(loopCarriedLoadValue.hasOneUse()); - - // Handle pipelining - i.e., `local_store`, `memdesc_subview`, - // `local_load` ops. - triton::gpu::LocalLoadOp localLoadOp = nullptr; - auto loadOpUser = *(loopCarriedLoadValue.user_begin()); - auto localStoreOp = llvm::dyn_cast(loadOpUser); - if (localStoreOp) { - auto subviewOp = localStoreOp.getDst() - .getDefiningOp(); - Value subviewResult = subviewOp.getResult(); - auto it = llvm::find(yieldedValues, subviewResult); - if (it != yieldedValues.end()) { - size_t idx = std::distance(yieldedValues.begin(), it); - Value loopCarriedSubviewValue = initArgs[idx]; - - auto subviewLoadOps = - getUsersOfType(loopCarriedSubviewValue); - assert(subviewLoadOps.size() == 1); - localLoadOp = *subviewLoadOps.begin(); - - loopCarriedLoadValue = localLoadOp.getResult(); - } else { - auto localLoadOps = - getUsersOfType(subviewResult); - assert(localLoadOps.size() == 1); - localLoadOp = *localLoadOps.begin(); - auto it = llvm::find(yieldedValues, localLoadOp.getResult()); - assert(it != yieldedValues.end()); - size_t idx = std::distance(yieldedValues.begin(), it); - loopCarriedLoadValue = initArgs[idx]; - } - maybeRewoundResults = rewindUnaryOps(loopCarriedLoadValue); - assert(succeeded(maybeRewoundResults)); - loopCarriedLoadValue = *maybeRewoundResults; - } - - // Find the corresponding `DotOp`. - auto dots = getUsersOfType(loopCarriedLoadValue); - assert(dots.size() == 1); - - // Find which `DotOp` argument the current `loadOp` belongs to. - auto dotOperands = dots.begin()->getOperands(); - auto it = llvm::find(dotOperands, loopCarriedLoadValue); - assert(it != dotOperands.end()); - size_t opIdx = std::distance(dotOperands.begin(), it); - - // Set `OpIdx` attributes. - auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); - - loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); - if (localStoreOp) - localStoreOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), - opIdxAttr); - } - } }; } // namespace diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index deb566a8b1b5..0f1636699ca7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -72,6 +73,7 @@ class StreamPipeliner { void scheduleDependencies(); void scheduleDistanceOneDependencies(); void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); + void labelLoadOpsForTritonDot(); bool preprocessLoopAndBuildSchedule(); bool pipelineLoop(); @@ -168,6 +170,11 @@ void StreamPipeliner::createStreamCopy( result = select->getResults(); } + if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + // sharedLoad->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + loadOp->replaceAllUsesWith(result); // Prefetch load ahead of the dot stage if is used by the dot. @@ -337,6 +344,52 @@ void StreamPipeliner::assignMemoryLayouts() { } } +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. +template +FailureOr rewindUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> FailureOr { + auto defOp = value.getDefiningOp(); + if (isa(value)) + return failure(); + if (defOp->getNumOperands() == 1) + return defOp; + return failure(); + }; + + auto maybeUnaryOp = getNextUnaryOps(value); + while (llvm::succeeded(maybeUnaryOp)) { + auto unaryOp = maybeUnaryOp.value(); + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + + maybeUnaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return failure(); +} + +void StreamPipeliner::labelLoadOpsForTritonDot() { + mlir::MLIRContext *ctx = forOp->getContext(); + + triton::DotOp dotOp; + size_t dotCounter = 0; + forOp->walk([&dotCounter, &dotOp](triton::DotOp op) { + dotOp = op; + ++dotCounter; + }); + + if (dotCounter == 1) { + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + auto maybeLoadOp = rewindUnaryOps(dotOperand); + if (llvm::succeeded(maybeLoadOp)) { + auto loadOp = maybeLoadOp.value(); + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } +} + void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Get all loads that are (transitively) used by dot ops and their distance // to the dot op. @@ -594,6 +647,8 @@ void StreamPipeliner::createStreamOps() { } bool StreamPipeliner::preprocessLoopAndBuildSchedule() { + labelLoadOpsForTritonDot(); + // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; From 088fbd95f0474dc38decbc99ca0ad0a0648d3f1d Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 21 Oct 2024 17:31:07 +0000 Subject: [PATCH 05/11] [AMD] Fixed bug in `setNumGeneratedGlobalLoads` * add a test for the presence of OpIdx attribute --- third_party/amd/backend/compiler.py | 2 +- .../amd/include/TritonAMDGPUToLLVM/Passes.h | 3 +- .../amd/include/TritonAMDGPUToLLVM/Passes.td | 4 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 4 + .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 105 ++++++++++++------ .../TritonAMDGPUToLLVM/SchedInstructions.h | 6 +- .../StreamPipelineV2.cpp | 16 ++- third_party/amd/python/triton_amd.cc | 5 +- 8 files changed, 95 insertions(+), 50 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 390d1c83e61d..8669f5e04707 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -274,7 +274,7 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ) diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index 969b357a74f6..4036cdecd1bd 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -38,7 +38,8 @@ createConvertBuiltinFuncToLLVMPass(bool ftz); std::unique_ptr> createTritonAMDGPUInsertInstructionSchedHintsPass(); std::unique_ptr> -createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index b9b06e47d217..47f8395fa836 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -69,12 +69,14 @@ def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruc def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")"; let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ + Option<"numStages", "num_stages", "int32_t", /*default*/"2", + "number of pipeline stages">, Option<"variant", "variant", "std::string", /*default*/"\"default\"", "instruction scheduling variant">, ]; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 437b64b438d1..ef0ef5e59132 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -310,6 +310,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -393,6 +396,7 @@ struct BufferLoadOpConversion Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + const int numVecs = numElems / vec; setNumGeneratedGlobalLoads(op, numVecs, vecTy); rewriter.replaceOp(op, {resultStruct}); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index b762bcb0e0b0..70747cacabe0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,11 +1,10 @@ +#include "SchedInstructions.h" #include "TritonAMDGPUToLLVM/Passes.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" -#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { #define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS @@ -28,22 +27,28 @@ void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, }); } -void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, Type type) { MLIRContext *ctx = op->getContext(); auto counterAttr = triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { - auto opIdxAttr = cast( - op->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())); - assert(opIdxAttr.getValue() < 2); - if (opIdxAttr.getValue() == 0) - schedHint.setNumGlobalLoadsAAttr(counterAttr); - else - schedHint.setNumGlobalLoadsBAttr(counterAttr); + if (auto opIdxAttr = op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumGlobalLoadsAAttr(counterAttr); + else + schedHint.setNumGlobalLoadsBAttr(counterAttr); + } }); } +template void setNumGeneratedGlobalLoads(triton::amdgpu::BufferLoadOp op, + size_t globalLoadsCount, Type type); +template void setNumGeneratedGlobalLoads(triton::LoadOp op, + size_t globalLoadsCount, Type type); void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, Type type) { @@ -72,15 +77,28 @@ void storeOpConversionCallback(triton::gpu::LocalStoreOp op, triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { - auto opIdxAttr = op->getAttrOfType( - triton::amdgpu::OpIdxAttr::getMnemonic()); - assert(opIdxAttr.getValue() < 2); - if (opIdxAttr.getValue() == 0) - schedHint.setNumDsWritesAAttr(counterAttr); - else - schedHint.setNumDsWritesBAttr(counterAttr); + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + } }); } + +llvm::FailureOr hasSingleDotOp(scf::ForOp forOp) { + triton::DotOp dotOp = nullptr; + size_t dotCounter = 0; + forOp->walk( + [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); + + if (dotCounter == 1) + return dotOp; + + return llvm::failure(); +} } // namespace mlir::triton namespace { @@ -119,8 +137,9 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) - : OpRewritePattern(ctx) { + InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, int32_t numStages, + std::string variant) + : OpRewritePattern(ctx), numStages(numStages) { std::transform(variant.begin(), variant.end(), variant.begin(), [](unsigned char c) { return std::tolower(c); }); @@ -130,6 +149,13 @@ struct InstructionSchedHintsRewriter .Case("iglp1", SchedulingType::IGLP1) .Case("ck_v3", SchedulingType::CK_V3) .Default(SchedulingType::UNKNOWN); + + if (this->numStages < 2) { + this->schedulingType = SchedulingType::NONE; + llvm::dbgs() << "[" << getDebugName() << "]: " + << "ignoring instruction scheduling due to a very low num. " + "stages value. Must be >= 2\n"; + } } enum class SchedulingType : uint32_t { @@ -160,6 +186,11 @@ struct InstructionSchedHintsRewriter const uint32_t numBufferLoadInstB = schedHint.getNumGlobalLoadsB().getValue(); + assert(numBufferLoadInstA && + "buffer load count for tile A must be initialized"); + assert(numBufferLoadInstB && + "buffer load count for tile B must be initialized"); + const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); auto mfmaType = cast(schedHint.getNumMMAs().getType()); @@ -184,7 +215,7 @@ struct InstructionSchedHintsRewriter // stage 1 const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma); - const auto num_mfma_per_issue = + const auto numMfmaPerIssue = numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; @@ -203,7 +234,7 @@ struct InstructionSchedHintsRewriter rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); createSchedGroupBarrier(rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, - num_mfma_per_issue - numDswritePerIssueA, 0); + numMfmaPerIssue - numDswritePerIssueA, 0); } for (size_t i = 0; i < numBufferLoadInstB; ++i) { @@ -219,7 +250,7 @@ struct InstructionSchedHintsRewriter rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); createSchedGroupBarrier(rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, - num_mfma_per_issue - numDswritePerIssueB, 0); + numMfmaPerIssue - numDswritePerIssueB, 0); } // stage 2 @@ -308,6 +339,7 @@ struct InstructionSchedHintsRewriter } private: + int32_t numStages; SchedulingType schedulingType; }; @@ -315,7 +347,9 @@ struct TritonAMDGPULowerInstructionSchedHints : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< TritonAMDGPULowerInstructionSchedHints> { - explicit TritonAMDGPULowerInstructionSchedHints(std::string variant) { + explicit TritonAMDGPULowerInstructionSchedHints(int32_t numStages, + std::string variant) { + this->numStages = numStages; this->variant = variant; } @@ -331,7 +365,8 @@ struct TritonAMDGPULowerInstructionSchedHints target.addLegalOp(); RewritePatternSet patterns(ctx); - patterns.add(ctx, this->variant); + patterns.add(ctx, this->numStages, + this->variant); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -343,24 +378,22 @@ struct TritonAMDGPULowerInstructionSchedHints struct TritonAMDGPUInsertInstructionSchedHints : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< TritonAMDGPUInsertInstructionSchedHints> { + void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); mod.walk([this, ctx](scf::ForOp forOp) { - triton::DotOp dot = nullptr; - size_t dotCounter = 0; - forOp->walk([&dot, &dotCounter](triton::DotOp op) { - dot = op; - ++dotCounter; - }); + auto maybeSingleDotOp = hasSingleDotOp(forOp); + // Note, instruction schedule barriers are inserted only in the case of // a single `tt.dot` op in a `scf::ForOp` scope in the current // implementation. - if (dotCounter == 1) { + if (llvm::succeeded(maybeSingleDotOp)) { + triton::DotOp dotOp = maybeSingleDotOp.value(); mlir::OpBuilder rewriter(ctx); - rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + rewriter.setInsertionPointAfter(dotOp); + rewriter.create(dotOp->getLoc()); } }); } @@ -369,8 +402,10 @@ struct TritonAMDGPUInsertInstructionSchedHints namespace mlir::triton { std::unique_ptr> -createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant) { + return std::make_unique(numStages, + variant); } std::unique_ptr> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h index 6b81dd0ab2db..58c036dffbbf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -2,6 +2,7 @@ #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H #include "mlir/IR/Types.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -11,12 +12,15 @@ namespace mlir::triton { void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, unsigned k, Type elementType); -void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, Type type); void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, Type type); void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, Type type); +llvm::FailureOr hasSingleDotOp(scf::ForOp forOp); } // namespace mlir::triton #endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 0f1636699ca7..d65b142d05d7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -2,6 +2,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -172,7 +173,6 @@ void StreamPipeliner::createStreamCopy( if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); - // sharedLoad->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } loadOp->replaceAllUsesWith(result); @@ -368,17 +368,15 @@ FailureOr rewindUnaryOps(Value value) { return failure(); } +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. void StreamPipeliner::labelLoadOpsForTritonDot() { mlir::MLIRContext *ctx = forOp->getContext(); + auto maybeSingleDotOp = triton::hasSingleDotOp(forOp); - triton::DotOp dotOp; - size_t dotCounter = 0; - forOp->walk([&dotCounter, &dotOp](triton::DotOp op) { - dotOp = op; - ++dotCounter; - }); - - if (dotCounter == 1) { + if (llvm::succeeded(maybeSingleDotOp)) { + triton::DotOp dotOp = maybeSingleDotOp.value(); for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { auto maybeLoadOp = rewindUnaryOps(dotOperand); if (llvm::succeeded(maybeLoadOp)) { diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 6f1df5a43fae..d30be6959839 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -48,8 +48,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", - [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(variant)); + [](mlir::PassManager &pm, int32_t numStages, std::string variant) { + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(numStages, + variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) { From c72bfb96b137ecec301d123c042b5e6e955aada8 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Tue, 22 Oct 2024 13:13:46 +0000 Subject: [PATCH 06/11] [AMD] added additional check into `createCKV3Schedule` The extra check tests whether the data are loaded from HBM using `buffer_load` instructions. The CKV3 scheduling is skipped if the check fails. --- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 4 ++- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 31 ++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 494e45819836..e4eec7f7f827 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -64,6 +64,8 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { TritonAMDGPU_InstCounter:$numDsWritesB, TritonAMDGPU_InstCounter:$numGlobalLoadsA, TritonAMDGPU_InstCounter:$numGlobalLoadsB, + BoolAttr:$isBufferLoadsAEnabled, + BoolAttr:$isBufferLoadsBEnabled, TritonAMDGPU_InstCounter:$numMMAs ); @@ -73,7 +75,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { auto type = IntegerType::get(ctx, 32); auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type); build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, - emptyAttr, emptyAttr, emptyAttr); + emptyAttr, emptyAttr, false, false, emptyAttr); }]> ]; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 70747cacabe0..10f680a6213b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -12,6 +12,11 @@ namespace mlir::triton { #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton +#undef DEBUG_TYPE +#define DEBUG_TYPE "lower-insert-instruction-sched-hints" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + using namespace mlir; namespace mlir::triton { @@ -38,10 +43,15 @@ void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, if (auto opIdxAttr = op->template getAttrOfType( triton::amdgpu::OpIdxAttr::getMnemonic())) { assert(opIdxAttr.getValue() < 2); - if (opIdxAttr.getValue() == 0) + const bool isBufferLoadOp = + std::is_same_v; + if (opIdxAttr.getValue() == 0) { schedHint.setNumGlobalLoadsAAttr(counterAttr); - else + schedHint.setIsBufferLoadsAEnabled(isBufferLoadOp); + } else { schedHint.setNumGlobalLoadsBAttr(counterAttr); + schedHint.setIsBufferLoadsBEnabled(isBufferLoadOp); + } } }); } @@ -152,9 +162,8 @@ struct InstructionSchedHintsRewriter if (this->numStages < 2) { this->schedulingType = SchedulingType::NONE; - llvm::dbgs() << "[" << getDebugName() << "]: " - << "ignoring instruction scheduling due to a very low num. " - "stages value. Must be >= 2\n"; + LDBG("ignoring instruction scheduling due to a very low num. " + "stages value. Must be >= 2"); } } @@ -175,6 +184,14 @@ struct InstructionSchedHintsRewriter void createCKV3Schedule(PatternRewriter &rewriter, Location loc, triton::amdgpu::InstructionSchedHint schedHint) const { + + if (!(schedHint.getIsBufferLoadsAEnabled() && + schedHint.getIsBufferLoadsBEnabled())) { + LDBG("Skipping instruction scheduling because `ckv3` " + "scheduling can be used only with `buffer_load` instructions"); + return; + } + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); @@ -288,9 +305,7 @@ struct InstructionSchedHintsRewriter PatternRewriter &rewriter) const override { if (this->schedulingType == SchedulingType::UNKNOWN) { - llvm::dbgs() - << "[" << getDebugName() << "]: " - << "unknown instruction scheduling variant has been provided\n"; + LDBG("unknown instruction scheduling variant has been provided"); return mlir::failure(); } From 68e7fac3f6796d9dded3515c80f8fc652f39ced0 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Wed, 23 Oct 2024 13:28:42 +0000 Subject: [PATCH 07/11] [AMD] Udated tests for `SchedInstructions` passes --- test/TritonGPU/amd/amd-instruction-sched.mlir | 227 +++++++----------- .../amd/include/TritonAMDGPUToLLVM/Passes.td | 1 + .../include/TritonAMDGPUTransforms/Passes.td | 2 +- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 30 ++- .../StreamPipelineV2.cpp | 92 ++++--- 5 files changed, 148 insertions(+), 204 deletions(-) diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index bca502f980cb..496bcd06ad4c 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,148 +1,87 @@ -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -check-prefix=INSTR_INSERTION -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 -triton-amdgpu-lower-insert-instruction-sched-hints=variant="iglp0" | FileCheck %s -check-prefix=LOWER_IGLP0 - -#shared0_ex0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#mma0_ex0 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> - -#blocked0_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1_ex1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared0_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#mma0_ex1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> -#dot0_ex1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex1, kWidth = 8}> -#dot1_ex1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex1, kWidth = 8}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // LOWER_IGLP0-LABEL: test_instruction_hints_lowering - tt.func @test_instruction_hints_lowering( - %arg0: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>>, - %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>>, - %arg2: tensor<32x32xf16, #mma0_ex0>) { - - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c64_i32 = arith.constant 1 : i32 - - scf.for %arg11 = %c0_i32 to %c64_i32 step %c1_i32 iter_args() -> () : i32 { - // LOWER_IGLP0: llvm.add - // LOWER_IGLP0-NEXT: %[[OPT_LEVEL:.*]] = llvm.mlir.constant(0 : i32) : i32 - // LOWER_IGLP0-NEXT: llvm.call_intrinsic "llvm.amdgcn.iglp.opt"(%[[OPT_LEVEL]]) : (i32) -> () - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>> -> tensor<32x32xf16, #mma0_ex0> - scf.yield - } - tt.return +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// XFAIL: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=unknown' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=INSERT_UNKNOWN +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// XFAIL: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD + + +module { + // INSERT_IGLP0-LABEL: @test_dot_op + // INSERT_IGLP1-LABEL: @test_dot_op + // INSTR_COUNT_NS1-LABEL: @test_dot_op + // INSTR_COUNT_NS2-LABEL: @test_dot_op + tt.func @test_dot_op(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + + %a_mask = arith.constant dense : tensor<128x32xi1> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16> + %b_mask = arith.constant dense : tensor<32x128xi1> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32> + + %a_off = arith.constant dense<4> : tensor<128x32xi32> + %b_off = arith.constant dense<4> : tensor<32x128xi32> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32>) { + %a = tt.load %a_ptr : tensor<128x32x!tt.ptr> + %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr> + + // INSERT_IGLP0: rocdl.iglp.opt 0 + // INSERT_IGLP1: rocdl.iglp.opt 1 + // INSERT_UNKNOWN: error: unknown instruction scheduling variant has been provided + + // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, i32> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, i32> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // INSTR_COUNT_NS2: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // USE_CKV3_GLOBAL_LOAD: error: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32> } - // INSTR_INSERTION-LABEL: @test_llvm_instruction_count - tt.func public @test_llvm_instruction_count( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) attributes {noinline = false} { - - %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked0_ex1> - %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked1_ex1> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c64_i32 = arith.constant 64 : i32 - %c63_i32 = arith.constant 63 : i32 - %c128_i32 = arith.constant 128 : i32 - %c256_i32 = arith.constant 256 : i32 - - %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> - %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> - %21 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> - %22 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> - %23 = arith.addi %21, %19 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> - %24 = arith.addi %22, %20 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> - - %26 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> - %27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> - %28 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> - %29 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> - %30 = arith.addi %28, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> - %31 = arith.addi %29, %27 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> - %32 = tt.expand_dims %23 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> -> tensor<256x1xi32, #blocked0_ex1> - %33 = tt.expand_dims %24 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> -> tensor<256x1xi32, #blocked2_ex1> - %34 = tt.splat %c64_i32 : i32 -> tensor<256x1xi32, #blocked0_ex1> - %35 = arith.muli %32, %34 : tensor<256x1xi32, #blocked0_ex1> - %36 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked0_ex1> - %37 = tt.addptr %36, %35 : tensor<256x1x!tt.ptr, #blocked0_ex1>, tensor<256x1xi32, #blocked0_ex1> - %38 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> - %39 = tt.expand_dims %38 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> -> tensor<1x64xi32, #blocked0_ex1> - %40 = tt.broadcast %37 : tensor<256x1x!tt.ptr, #blocked0_ex1> -> tensor<256x64x!tt.ptr, #blocked0_ex1> - %41 = tt.broadcast %39 : tensor<1x64xi32, #blocked0_ex1> -> tensor<256x64xi32, #blocked0_ex1> - %42 = tt.addptr %40, %41 : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> - - %43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> - %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> -> tensor<64x1xi32, #blocked1_ex1> - %45 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1_ex1> - %46 = tt.addptr %45, %44 : tensor<64x1x!tt.ptr, #blocked1_ex1>, tensor<64x1xi32, #blocked1_ex1> - %47 = tt.expand_dims %30 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> -> tensor<1x128xi32, #blocked1_ex1> - %48 = tt.expand_dims %31 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> -> tensor<1x128xi32, #blocked2_ex1> - %49 = tt.splat %c64_i32 : i32 -> tensor<1x128xi32, #blocked1_ex1> - %50 = arith.muli %47, %49 : tensor<1x128xi32, #blocked1_ex1> - %51 = tt.broadcast %46 : tensor<64x1x!tt.ptr, #blocked1_ex1> -> tensor<64x128x!tt.ptr, #blocked1_ex1> - %52 = tt.broadcast %50 : tensor<1x128xi32, #blocked1_ex1> -> tensor<64x128xi32, #blocked1_ex1> - %53 = tt.addptr %51, %52 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> - - %56 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> - %57 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32> + %c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma0_ex1> - - %cc0_i1 = arith.constant 1 : i1 - %59 = tt.splat %cc0_i1 : i1 -> tensor<256x64xi1, #blocked0_ex1> - %60 = tt.load %42, %59 : tensor<256x64x!tt.ptr, #blocked0_ex1> - %61 = tt.splat %cc0_i1 : i1 -> tensor<64x128xi1, #blocked1_ex1> - %62 = tt.load %53, %61 : tensor<64x128x!tt.ptr, #blocked1_ex1> - - %63 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %60, %63 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> - %64 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %62, %64 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> - - %66:5 = scf.for %arg11 = %c0_i32 to %c63_i32 step %c1_i32 iter_args( - %arg12 = %cst_1, - %arg13 = %42, - %arg14 = %53, - %arg16 = %63, - %arg17 = %64) -> ( - tensor<256x128xf32, #mma0_ex1>, - tensor<256x64x!tt.ptr, #blocked0_ex1>, - tensor<64x128x!tt.ptr, #blocked1_ex1>, - !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, - !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>) : i32 { - - %82 = triton_gpu.local_load %arg16 : !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dot0_ex1> - %83 = triton_gpu.local_load %arg17 : !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> tensor<64x128xf16, #dot1_ex1> - - // INSTR_INSERTION: amdgpu.instruction_sched_hint - // INSTR_INSERTION-SAME: numDsReadsA = #amdgpu.InstCounter<16, vector<8xf16>> - // INSTR_INSERTION-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<8xf16>> - // INSTR_INSERTION-SAME: numDsWritesA = #amdgpu.InstCounter<8, vector<8xf16>> - // INSTR_INSERTION-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<8xf16>> - // INSTR_INSERTION-SAME: numGlobalLoadsA = #amdgpu.InstCounter<8, vector<8xf16>> - // INSTR_INSERTION-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<8xf16>> - // INSTR_INSERTION-SAME: numMMAs = #amdgpu.InstCounter<64, tensor<32x32x8xf16>> - - %84 = tt.dot %82, %83, %arg12 : tensor<256x64xf16, #dot0_ex1> * tensor<64x128xf16, #dot1_ex1> -> tensor<256x128xf32, #mma0_ex1> - %85 = tt.addptr %arg13, %cst : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> - %86 = tt.addptr %arg14, %cst_0 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> - %87 = tt.load %85 : tensor<256x64x!tt.ptr, #blocked0_ex1> - %88 = tt.load %86 : tensor<64x128x!tt.ptr, #blocked1_ex1> - %89 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %87, %89 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> - %90 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %88, %90 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> - - scf.yield %84, %85, %86, %89, %90 : - tensor<256x128xf32, #mma0_ex1>, - tensor<256x64x!tt.ptr, #blocked0_ex1>, - tensor<64x128x!tt.ptr, #blocked1_ex1>, - !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, - !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> - } - tt.return - } + tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr> + tt.return +} } diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 47f8395fa836..0c1ccee76d77 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -72,6 +72,7 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")"; let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 433e60be67f6..93345b0d6de4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -13,7 +13,7 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; - let dependentDialects = []; + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"numStages", "num_stages", diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 10f680a6213b..b6524fe23d61 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -98,7 +98,7 @@ void storeOpConversionCallback(triton::gpu::LocalStoreOp op, }); } -llvm::FailureOr hasSingleDotOp(scf::ForOp forOp) { +mlir::FailureOr hasSingleDotOp(scf::ForOp forOp) { triton::DotOp dotOp = nullptr; size_t dotCounter = 0; forOp->walk( @@ -107,7 +107,7 @@ llvm::FailureOr hasSingleDotOp(scf::ForOp forOp) { if (dotCounter == 1) return dotOp; - return llvm::failure(); + return mlir::failure(); } } // namespace mlir::triton @@ -187,8 +187,9 @@ struct InstructionSchedHintsRewriter if (!(schedHint.getIsBufferLoadsAEnabled() && schedHint.getIsBufferLoadsBEnabled())) { - LDBG("Skipping instruction scheduling because `ckv3` " - "scheduling can be used only with `buffer_load` instructions"); + schedHint.emitError( + "Skipping instruction scheduling because `ck_v3` " + "scheduling can be used only with `buffer_load` instructions"); return; } @@ -203,10 +204,11 @@ struct InstructionSchedHintsRewriter const uint32_t numBufferLoadInstB = schedHint.getNumGlobalLoadsB().getValue(); - assert(numBufferLoadInstA && - "buffer load count for tile A must be initialized"); - assert(numBufferLoadInstB && - "buffer load count for tile B must be initialized"); + if (numBufferLoadInstA == 0) + schedHint.emitError("buffer load count for tile A must be initialized"); + + if (numBufferLoadInstB == 0) + schedHint->emitError("buffer load count for tile B must be initialized"); const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); @@ -305,7 +307,8 @@ struct InstructionSchedHintsRewriter PatternRewriter &rewriter) const override { if (this->schedulingType == SchedulingType::UNKNOWN) { - LDBG("unknown instruction scheduling variant has been provided"); + instructionSchedHint.emitError( + "unknown instruction scheduling variant has been provided"); return mlir::failure(); } @@ -380,11 +383,14 @@ struct TritonAMDGPULowerInstructionSchedHints target.addLegalOp(); RewritePatternSet patterns(ctx); + patterns.add(ctx, this->numStages, + this->variant); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + if (mlir::failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); } } @@ -404,7 +410,7 @@ struct TritonAMDGPUInsertInstructionSchedHints // Note, instruction schedule barriers are inserted only in the case of // a single `tt.dot` op in a `scf::ForOp` scope in the current // implementation. - if (llvm::succeeded(maybeSingleDotOp)) { + if (mlir::succeeded(maybeSingleDotOp)) { triton::DotOp dotOp = maybeSingleDotOp.value(); mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dotOp); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index d65b142d05d7..6fd119bc4cdf 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -74,7 +74,6 @@ class StreamPipeliner { void scheduleDependencies(); void scheduleDistanceOneDependencies(); void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); - void labelLoadOpsForTritonDot(); bool preprocessLoopAndBuildSchedule(); bool pipelineLoop(); @@ -344,50 +343,6 @@ void StreamPipeliner::assignMemoryLayouts() { } } -// Go through a single use chain to get the result of the target op after all -// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. -template -FailureOr rewindUnaryOps(Value value) { - auto getNextUnaryOps = [](Value value) -> FailureOr { - auto defOp = value.getDefiningOp(); - if (isa(value)) - return failure(); - if (defOp->getNumOperands() == 1) - return defOp; - return failure(); - }; - - auto maybeUnaryOp = getNextUnaryOps(value); - while (llvm::succeeded(maybeUnaryOp)) { - auto unaryOp = maybeUnaryOp.value(); - if (llvm::dyn_cast(unaryOp)) - return unaryOp; - - maybeUnaryOp = getNextUnaryOps(unaryOp->getOperand(0)); - } - return failure(); -} - -// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand -// index. Note, this is a part of the instruction scheduling routine. Currently, -// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. -void StreamPipeliner::labelLoadOpsForTritonDot() { - mlir::MLIRContext *ctx = forOp->getContext(); - auto maybeSingleDotOp = triton::hasSingleDotOp(forOp); - - if (llvm::succeeded(maybeSingleDotOp)) { - triton::DotOp dotOp = maybeSingleDotOp.value(); - for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { - auto maybeLoadOp = rewindUnaryOps(dotOperand); - if (llvm::succeeded(maybeLoadOp)) { - auto loadOp = maybeLoadOp.value(); - auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); - loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); - } - } - } -} - void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Get all loads that are (transitively) used by dot ops and their distance // to the dot op. @@ -645,8 +600,6 @@ void StreamPipeliner::createStreamOps() { } bool StreamPipeliner::preprocessLoopAndBuildSchedule() { - labelLoadOpsForTritonDot(); - // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; @@ -738,6 +691,50 @@ bool StreamPipeliner::pipelineLoop() { } namespace { +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. +template +FailureOr rewindUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> FailureOr { + auto defOp = value.getDefiningOp(); + if (isa(value)) + return failure(); + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + return failure(); + }; + + auto maybeUnaryOp = getNextUnaryOps(value); + while (llvm::succeeded(maybeUnaryOp)) { + auto unaryOp = maybeUnaryOp.value(); + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + + maybeUnaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return failure(); +} + +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. +void labelLoadOpsForTritonDot(scf::ForOp forOp) { + mlir::MLIRContext *ctx = forOp->getContext(); + auto maybeSingleDotOp = triton::hasSingleDotOp(forOp); + + if (llvm::succeeded(maybeSingleDotOp)) { + triton::DotOp dotOp = maybeSingleDotOp.value(); + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + auto maybeLoadOp = rewindUnaryOps(dotOperand); + if (llvm::succeeded(maybeLoadOp)) { + auto loadOp = maybeLoadOp.value(); + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } +} + struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { PipelinePass() = default; PipelinePass(int32_t numStages) { this->numStages = numStages; } @@ -745,6 +742,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { void runOnOperation() override { SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { + labelLoadOpsForTritonDot(forOp); // Bail out for loops with num_stage <= 1. if (getNumStagesOrDefault(forOp) > 1) loops.push_back(forOp); From dd7d2c6faecd98da9f2f1a4f3cc5347630435f2c Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 28 Oct 2024 12:49:24 +0000 Subject: [PATCH 08/11] [AMD] Addressed comments of PR#4940 --- test/TritonGPU/amd/amd-instruction-sched.mlir | 26 +++++++++--- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 4 +- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 26 ++++++------ .../TritonAMDGPUToLLVM/SchedInstructions.h | 2 +- .../StreamPipelineV2.cpp | 40 +++++++++---------- 5 files changed, 54 insertions(+), 44 deletions(-) diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 496bcd06ad4c..a1814f4dfcdb 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,16 +1,18 @@ // RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 // RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 -// XFAIL: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=unknown' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=INSERT_UNKNOWN // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 -// XFAIL: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD - +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 module { // INSERT_IGLP0-LABEL: @test_dot_op // INSERT_IGLP1-LABEL: @test_dot_op // INSTR_COUNT_NS1-LABEL: @test_dot_op // INSTR_COUNT_NS2-LABEL: @test_dot_op + // LABELING_PS_1-LABEL: @test_dot_op + // LABELING_PS_2-LABEL: @test_dot_op tt.func @test_dot_op(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, @@ -43,7 +45,6 @@ module { // INSERT_IGLP0: rocdl.iglp.opt 0 // INSERT_IGLP1: rocdl.iglp.opt 1 - // INSERT_UNKNOWN: error: unknown instruction scheduling variant has been provided // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false @@ -67,7 +68,22 @@ module { // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> - // USE_CKV3_GLOBAL_LOAD: error: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions + // USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints] + // USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions. + + // LABELING_PS_1: scf.for + // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_1: %[[REG1_OP0:.+]] = triton_gpu.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = triton_gpu.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} + + // LABELING_PS_2: scf.for + // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index e4eec7f7f827..68c50d48635b 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -72,8 +72,8 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let builders = [ OpBuilder<(ins), [{ auto ctx = $_state.getContext(); - auto type = IntegerType::get(ctx, 32); - auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type); + auto noneType = NoneType::get(ctx); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType); build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, emptyAttr, emptyAttr, false, false, emptyAttr); }]> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index b6524fe23d61..99047f2024d7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -19,6 +19,11 @@ namespace mlir::triton { using namespace mlir; +// TODO: The following passes/algorithms are applicable only for a single +// `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block. +// Note, we need to relax this assumption in the future and extend the current +// implementation. + namespace mlir::triton { void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, unsigned k, Type elementType) { @@ -98,16 +103,13 @@ void storeOpConversionCallback(triton::gpu::LocalStoreOp op, }); } -mlir::FailureOr hasSingleDotOp(scf::ForOp forOp) { +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) { triton::DotOp dotOp = nullptr; size_t dotCounter = 0; forOp->walk( [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); - if (dotCounter == 1) - return dotOp; - - return mlir::failure(); + return (dotCounter == 1) ? dotOp : nullptr; } } // namespace mlir::triton @@ -178,7 +180,7 @@ struct InstructionSchedHintsRewriter // This is the implementation of the CK's V3 pipelining (see // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). // This scheduling requires 1x register and 1x LDS buffers combined with the - // local (LDS to registers) and global (HBN to registers) data prefetching. + // local (LDS to registers) and global (HBM to registers) data prefetching. // see: // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h void @@ -187,9 +189,8 @@ struct InstructionSchedHintsRewriter if (!(schedHint.getIsBufferLoadsAEnabled() && schedHint.getIsBufferLoadsBEnabled())) { - schedHint.emitError( - "Skipping instruction scheduling because `ck_v3` " - "scheduling can be used only with `buffer_load` instructions"); + LDBG("Skipping instruction scheduling because `ck_v3` " + "scheduling can be used only with `buffer_load` instructions."); return; } @@ -208,7 +209,7 @@ struct InstructionSchedHintsRewriter schedHint.emitError("buffer load count for tile A must be initialized"); if (numBufferLoadInstB == 0) - schedHint->emitError("buffer load count for tile B must be initialized"); + schedHint.emitError("buffer load count for tile B must be initialized"); const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); @@ -405,13 +406,10 @@ struct TritonAMDGPUInsertInstructionSchedHints ModuleOp mod = getOperation(); mod.walk([this, ctx](scf::ForOp forOp) { - auto maybeSingleDotOp = hasSingleDotOp(forOp); - // Note, instruction schedule barriers are inserted only in the case of // a single `tt.dot` op in a `scf::ForOp` scope in the current // implementation. - if (mlir::succeeded(maybeSingleDotOp)) { - triton::DotOp dotOp = maybeSingleDotOp.value(); + if (auto dotOp = getSingleDotOpIfExists(forOp)) { mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dotOp); rewriter.create(dotOp->getLoc()); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h index 58c036dffbbf..45985fe808f2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -20,7 +20,7 @@ void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, Type type); void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, Type type); -llvm::FailureOr hasSingleDotOp(scf::ForOp forOp); +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp); } // namespace mlir::triton #endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 6fd119bc4cdf..3b4935026c3f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -170,6 +170,11 @@ void StreamPipeliner::createStreamCopy( result = select->getResults(); } + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } @@ -693,26 +698,22 @@ bool StreamPipeliner::pipelineLoop() { namespace { // Go through a single use chain to get the result of the target op after all // unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. -template -FailureOr rewindUnaryOps(Value value) { - auto getNextUnaryOps = [](Value value) -> FailureOr { - auto defOp = value.getDefiningOp(); - if (isa(value)) - return failure(); - if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) - return defOp; - return failure(); +template Operation *passPrevUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> Operation * { + if (auto defOp = value.getDefiningOp()) { + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + } + return nullptr; }; - auto maybeUnaryOp = getNextUnaryOps(value); - while (llvm::succeeded(maybeUnaryOp)) { - auto unaryOp = maybeUnaryOp.value(); + auto unaryOp = getNextUnaryOps(value); + while (unaryOp) { if (llvm::dyn_cast(unaryOp)) return unaryOp; - - maybeUnaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + unaryOp = getNextUnaryOps(unaryOp->getOperand(0)); } - return failure(); + return nullptr; } // Annotate each `tt.LoadOp` instruction with its corresponding gemm operand @@ -720,14 +721,9 @@ FailureOr rewindUnaryOps(Value value) { // we support `forOp`s which contain only a single `tt.DotOp` in the bodies. void labelLoadOpsForTritonDot(scf::ForOp forOp) { mlir::MLIRContext *ctx = forOp->getContext(); - auto maybeSingleDotOp = triton::hasSingleDotOp(forOp); - - if (llvm::succeeded(maybeSingleDotOp)) { - triton::DotOp dotOp = maybeSingleDotOp.value(); + if (auto dotOp = triton::getSingleDotOpIfExists(forOp)) { for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { - auto maybeLoadOp = rewindUnaryOps(dotOperand); - if (llvm::succeeded(maybeLoadOp)) { - auto loadOp = maybeLoadOp.value(); + if (auto loadOp = passPrevUnaryOps(dotOperand)) { auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); } From a2f8874b0851b7a3f1d0fc250b86f551cb80cc19 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 28 Oct 2024 17:45:09 +0000 Subject: [PATCH 09/11] [AMD] Fixed propagation of OpIdx attribute in `LoadToBufferLoad` pass --- .../ConvertToBufferOps.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index f1d922041fcf..e66a2feb57fe 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -177,8 +177,21 @@ struct ConvertTritonLoadToBufferLoad Value maybeMask{}; if (op.getMask() && !isZeroConst(op.getMask())) maybeMask = op.getMask(); - rewriter.replaceOpWithNewOp( - op, op.getType(), basePtr, tensorOffset, maybeMask, maybeOther); + + auto bufferLoadOp = rewriter.create( + op->getLoc(), op.getType(), basePtr, tensorOffset, maybeMask, + maybeOther); + + // Propagate `OpIdxAttr` if the currently processed `tt.LoadOp` was + // labeled it. The attribute needs to be preserved for custom instruction + // scheduling. + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + bufferLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); + } + rewriter.replaceOp(op, bufferLoadOp); + return success(); } LDBG("Failed to convert: " << op); From ae8c3c87e8b528d9036b40676871662ddad38c04 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Tue, 29 Oct 2024 10:18:39 +0000 Subject: [PATCH 10/11] [AMD] Fixed instruction.sched lit tests --- test/TritonGPU/amd/amd-instruction-sched.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index a1814f4dfcdb..400c219b6790 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -51,8 +51,8 @@ module { // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> - // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, i32> - // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, i32> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> @@ -82,7 +82,7 @@ module { // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} - // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> From 2504666b40568ea8cb5bd511c3372a0eeabc4bc7 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 29 Oct 2024 22:44:47 +0000 Subject: [PATCH 11/11] Drop unnecessary mlir::prefix and early return if none choice --- .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 99047f2024d7..62ef7a164337 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -126,7 +126,7 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, rewriter.getI32IntegerAttr(static_cast(sizeValue)); IntegerAttr groupId = rewriter.getI32IntegerAttr(static_cast(groupIdValue)); - rewriter.create(loc, mask, size, groupId); + rewriter.create(loc, mask, size, groupId); } // Insert intrinsic that controls the types of instructions that may be @@ -135,7 +135,7 @@ Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, mlir::amdgpu::sched_barrier_opt_enum maskValue) { IntegerAttr mask = rewriter.getI32IntegerAttr(static_cast(maskValue)); - return rewriter.create(loc, mask); + return rewriter.create(loc, mask); } // Insert an experimental intrinsic for instruction group level parallelism. @@ -143,13 +143,13 @@ Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { IntegerAttr iglpValue = rewriter.getI32IntegerAttr(static_cast(value)); - return rewriter.create(loc, iglpValue); + return rewriter.create(loc, iglpValue); } struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, int32_t numStages, + InstructionSchedHintsRewriter(MLIRContext *ctx, int32_t numStages, std::string variant) : OpRewritePattern(ctx), numStages(numStages) { std::transform(variant.begin(), variant.end(), variant.begin(), @@ -306,11 +306,15 @@ struct InstructionSchedHintsRewriter LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { + if (this->schedulingType == SchedulingType::NONE) { + rewriter.eraseOp(instructionSchedHint); + return success(); + } if (this->schedulingType == SchedulingType::UNKNOWN) { instructionSchedHint.emitError( "unknown instruction scheduling variant has been provided"); - return mlir::failure(); + return failure(); } // The switch controls whether instructions are allowed to cross the basic @@ -354,7 +358,7 @@ struct InstructionSchedHintsRewriter mlir::amdgpu::sched_barrier_opt_enum::none); rewriter.eraseOp(instructionSchedHint); - return mlir::success(); + return success(); } private: @@ -389,8 +393,8 @@ struct TritonAMDGPULowerInstructionSchedHints this->variant); - if (mlir::failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { signalPassFailure(); } @@ -410,7 +414,7 @@ struct TritonAMDGPUInsertInstructionSchedHints // a single `tt.dot` op in a `scf::ForOp` scope in the current // implementation. if (auto dotOp = getSingleDotOpIfExists(forOp)) { - mlir::OpBuilder rewriter(ctx); + OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dotOp); rewriter.create(dotOp->getLoc()); }