diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 25a891c2f7d9..27e9eb6914bb 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -59,6 +59,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // TritonAMDGPUTransforms passes mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUBypassLDSForDotOperand(); mlir::registerTritonAMDGPUReorderInstructions(); mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index e688b52245ee..80cd95f907fd 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -192,6 +192,8 @@ bool isPureUnaryInlineAsm(Operation *op); // read the compute capability from the module attributes int getNVIDIAComputeCapability(Operation *module); +// Convert \param op operands and results to layout \param encoding. 
+void convertOpEncoding(Attribute encoding, Operation *op); } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index b3814329ae72..cf0f952ce971 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -104,55 +104,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { threadsPerWarp, CTALayout); } - static Type getNewType(Type type, Attribute encoding) { - RankedTensorType tensorType = cast(type); - return RankedTensorType::get(tensorType.getShape(), - tensorType.getElementType(), encoding); - } - - void coalesceOp(Attribute encoding, Operation *op) { - OpBuilder builder(op); - // Convert operands - // For load/store with tensor pointers, we don't have to change the - // operands' type, we do this by changing the outputs' type of - // `make_tensor_ptr` - SmallVector newArgs; - for (auto operand : op->getOperands()) { - auto tensorType = dyn_cast(operand.getType()); - if (tensorType && - !isa(tensorType.getEncoding())) { - Type newType = getNewType(tensorType, encoding); - newArgs.push_back(builder.create( - op->getLoc(), newType, operand)); - } else { - newArgs.push_back(operand); - } - } - - // Convert output types - SmallVector newTypes; - for (auto t : op->getResultTypes()) { - bool isAsync = isa(op); - newTypes.push_back(isAsync ? 
t : getNewType(t, encoding)); - } - - // Construct new op with the new encoding - Operation *newOp = - builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, - newTypes, op->getAttrs()); - - // Cast the results back to the original layout - for (size_t i = 0; i < op->getNumResults(); i++) { - Value newResult = newOp->getResult(i); - if (newTypes[i] != op->getResultTypes()[i]) { - newResult = builder.create( - op->getLoc(), op->getResult(i).getType(), newResult); - } - op->getResult(i).replaceAllUsesWith(newResult); - } - op->erase(); - } - void runOnOperation() override { // Run axis info analysis ModuleOp moduleOp = getOperation(); @@ -187,7 +138,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { // 4. Convert the output of this new memory op back to L1 // 5. Replace all the uses of the original memory op by the new one for (auto &kv : layoutMap) { - coalesceOp(kv.second, kv.first); + convertOpEncoding(kv.second, kv.first); } } }; diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index cee1ae84ef59..8d45c0ca19f1 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -967,10 +967,15 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { void LayoutRematerialization::backwardRematerialization( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristic to accommodate fused attention + // Skip conversions to DotOperandEncodingAttr when the operand index is 0. + // This heuristic is applied to prevent moving the blocked->dot conversion of + // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing + // so can increase register pressure and cause spilling in some cases. 
+ // TODO: Fix this logic to avoid propagating conversions backward unless + // it reduces the total number of conversions. RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + auto dotEnc = dyn_cast(targetType.getEncoding()); + if (dotEnc && dotEnc.getOpIdx() == 0) return; Value oldV = convertOp->getOperand(0); LDBG("check backward remat with source " << oldV << " encoding " @@ -1010,10 +1015,15 @@ void LayoutRematerialization::backwardRematerialization( // of the convert. void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention + // Skip conversions to DotOperandEncodingAttr when the operand index is 0. + // This heuristic is applied to prevent moving the blocked->dot conversion of + // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing + // so can increase register pressure and cause spilling in some cases. + // TODO: Fix this logic to avoid propagating conversions backward unless + // it reduces the total number of conversions. 
RankedTensorType targetType = convertOp.getType(); - if (mlir::isa(targetType.getEncoding())) + auto dotEnc = dyn_cast(targetType.getEncoding()); + if (dotEnc && dotEnc.getOpIdx() == 0) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 4ef9d1cd1d11..176e6ea5970b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -930,6 +930,54 @@ int getNVIDIAComputeCapability(Operation *module) { return computeCapability; } +static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +void convertOpEncoding(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? 
t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(), + newArgs, newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/test/TritonGPU/amd/bypass-lds.mlir b/test/TritonGPU/amd/bypass-lds.mlir new file mode 100644 index 000000000000..d2394f5a8f4d --- /dev/null +++ b/test/TritonGPU/amd/bypass-lds.mlir @@ -0,0 +1,158 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-bypass-lds-for-dot-operand -tritonamdgpu-stream-pipeline-v2=num_stages=2 -tritongpu-remove-layout-conversions | FileCheck %s + +// For Bypass LDS optimization to be efficient we need collaboration of 3 passes: +// 1) Pipelining pass: This is because pipelining in registers is necessary. +// 2) Bypass LDS pass: To convert load from blocked->dot layout. +// 3) Remove layout conversion pass: To remove blocked->dot layout by changing layout of all ops that form tensor of pointers to dot layout. +// Check that all of the optimizations were done properly to create efficient IR. 
+ +// CHECK-LABEL: bypass_lds +// No blocked->dot (opIdx = 1) layout conversion may remain ahead of the first dot-layout load. +// CHECK-NOT: triton_gpu.convert_layout %{{.*}} : tensor<{{.*}}, #blocked{{.*}}> -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> +// CHECK: %[[DOT_LOAD_1:.+]] = tt.load %{{.*}} : tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> +// CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %[[DOT_LOAD_1]], %{{.*}} = %{{.*}}) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked>, i32, !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) +// CHECK: %[[DOT_LOAD_2:.+]] = tt.load %{{.*}} : tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> +// The value yielded in the operand-1 slot must be the in-loop dot-layout load (use, not re-definition, of DOT_LOAD_2). +// CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[DOT_LOAD_2]], %{{.*}} : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked>, i32, !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#sliced_blocked1 = #triton_gpu.slice<{parent=#blocked1, dim=0}> +#sliced_blocked2 = #triton_gpu.slice<{parent=#blocked2, dim=0}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}> +module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @bypass_lds(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma> + %a_ptr_splat = tt.splat %arg0 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked1> + %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %b_ptr_splat = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked2> + %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, 
tensor<64x256x!tt.ptr, #blocked2>) : i32 { + %74 = tt.load %arg12 : tensor<256x64x!tt.ptr, #blocked1> + %75 = tt.load %arg13 : tensor<64x256x!tt.ptr, #blocked2> + %76 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> + %77 = triton_gpu.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> + %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<256x256xf32, #mfma> + %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked2> + } + %addr_res = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked3> + %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma> + %72 = triton_gpu.convert_layout %addr_res : tensor<256x256x!tt.ptr, #blocked3> -> tensor<256x256x!tt.ptr, #mfma> + tt.store %72, %57 : tensor<256x256x!tt.ptr, #mfma> + tt.return + } +} + +// ----- + +// Check that bypass LDS optimization is not done because warpsPerCTA condition is not satisfied. 
+ +// CHECK-LABEL: no_bypass_lds_warps_per_cta +// CHECK-NOT: tt.load %{{.*}} : tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#sliced_blocked1 = #triton_gpu.slice<{parent=#blocked1, dim=0}> +#sliced_blocked2 = #triton_gpu.slice<{parent=#blocked2, dim=0}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_bypass_lds_warps_per_cta(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma> + %a_ptr_splat = tt.splat %arg0 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked1> + %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : 
tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %b_ptr_splat = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked2> + %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked2>) : i32 { + %74 = tt.load %arg12 : tensor<256x64x!tt.ptr, #blocked1> + %75 = tt.load %arg13 : tensor<64x256x!tt.ptr, #blocked2> + %76 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> + %77 = triton_gpu.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> + %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<256x256xf32, #mfma> + %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked2> + } + %addr_res = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked3> + %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma> + %72 = triton_gpu.convert_layout %addr_res : 
tensor<256x256x!tt.ptr, #blocked3> -> tensor<256x256x!tt.ptr, #mfma> + tt.store %72, %57 : tensor<256x256x!tt.ptr, #mfma> + tt.return + } +} + +// ----- + +// Check that bypass LDS optimization is not done because kWidth condition is not satisfied. + +// CHECK-LABEL: no_bypass_lds_kWidth +// CHECK-NOT: tt.load %{{.*}} : tensor<64x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#sliced_blocked1 = #triton_gpu.slice<{parent=#blocked1, dim=0}> +#sliced_blocked2 = #triton_gpu.slice<{parent=#blocked2, dim=0}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_bypass_lds_kWidth(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma> + %a_ptr_splat = tt.splat %arg0 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked1> + %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1> + %a_tmp1 = tt.expand_dims 
%a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %b_ptr_splat = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked2> + %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked2>) : i32 { + %74 = tt.load %arg12 : tensor<256x64x!tt.ptr, #blocked1> + %75 = tt.load %arg13 : tensor<64x256x!tt.ptr, #blocked2> + %76 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x256xf32, #mfma> + %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr, #blocked2>, tensor<64x256xi32, #blocked2> + scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr, #blocked1>, 
tensor<64x256x!tt.ptr, #blocked2> + } + %addr_res = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked3> + %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma> + %72 = triton_gpu.convert_layout %addr_res : tensor<256x256x!tt.ptr, #blocked3> -> tensor<256x256x!tt.ptr, #mfma> + tt.store %72, %57 : tensor<256x256x!tt.ptr, #mfma> + tt.return + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 682c1cb3019d..d1789082f31e 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -288,6 +288,23 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 } } +// CHECK-LABEL: @check_dot_op_idx1_propagation +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +tt.func @check_dot_op_idx1_propagation(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-NOT: triton_gpu.convert_layout {{.*}} : {{.*}} -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = {{.*}}}>> + // CHECK: tt.load {{.+}} : tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = {{.*}}}>> + %cst_1 = arith.constant dense<1> : tensor<64x64xi32, #blocked1> + %splat_load_addr = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked1> + %splat_store_addr = tt.splat %arg1 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked1> + %1 = tt.addptr %splat_load_addr, %cst_1 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> + %3 = tt.load %2 : tensor<64x64x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> + %4 = triton_gpu.convert_layout %3 : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x64xf32, #blocked1> + tt.store %splat_store_addr, %4 : tensor<64x64x!tt.ptr, #blocked1> + tt.return +} +} + // CHECK-LABEL: loop module 
attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 3d215a635da3..83ca58a4c09e 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -740,14 +740,17 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// COMMON-LABEL: tt.func @dep_arg_two_uses -// COMMON: tt.expand_dims -// COMMON: tt.expand_dims -// COMMON: tt.expand_dims %arg5 -// COMMON-NEXT: tt.expand_dims %arg5 -// COMMON: %[[PTR0:.*]] = tt.splat %arg6 -// COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] -// COMMON-NEXT: tt.load %[[PTR1]] +// CHECK-LABEL: tt.func @dep_arg_two_uses +// CHECK: tt.expand_dims +// CHECK: tt.expand_dims +// CHECK: tt.expand_dims %arg5 +// CHECK-NEXT: tt.expand_dims %arg5 +// CHECK: %[[PTR0:.*]] = tt.splat %arg6 +// CHECK: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// CHECK-NEXT: tt.load %[[PTR1]] +// AMD-LABEL: tt.func @dep_arg_two_uses +// AMD-COUNT-3: tt.load +// AMD: scf.for tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index a53a06dd4248..a7c87df217be 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -215,6 +215,8 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) + amd.passes.ttgpuir.add_tritongpu_bypass_lds_for_dot_operand(pm) + if amd.has_matrix_core_feature(options.arch): assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. 
" "We used to trigger software pipelining with " diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 841137887ba0..0fd90ff63f78 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -20,6 +20,7 @@ std::unique_ptr createTritonAMDGPUReorderInstructionsPass(); std::unique_ptr createTritonAMDGPUVerifier(); std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); +std::unique_ptr createTritonAMDGPUBypassLDSForDotOperand(); std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index d59935e796fa..9fc7b187b853 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -96,7 +96,34 @@ def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers" let constructor = "mlir::createTritonAMDGPUCanonicalizePointersPass()"; let dependentDialects = []; +} + +def TritonAMDGPUBypassLDSForDotOperand: Pass<"tritonamdgpu-bypass-lds-for-dot-operand", "mlir::ModuleOp"> { + let summary = "Bypass moving data trough LDS for dot operand when possible."; + let description = [{ + Under certain conditions, the dot layout of one of the operands allows direct + loading from HBM to VGPRs in the MFMA dot layout, without losing of vectorization of global loads + or increasing the number of global loads due to shared data between threads. + The required conditions are: + 1. K-Major Tensor Layout: + The operand we want to bypass LDS for must be K-major (i.e., row-major for + operand 0 or column-major for operand 1). This supports vectorized global + load instructions, as MFMA instructions require each thread to hold B + operand elements along the K dimension. + 2. 
kWidth * sizeof(dataType) == 128: + Using the maximum kWidth for a specific data type ensures optimal global + load vectorization (e.g., using global_load_dwordx4 instructions). + 3. Single Warp per CTA Dimension: + Either warpsPerCTA[ndim] == 1 for operand A bypass or warpsPerCTA[mDim] == + 1 for operand B bypass. This guarantees that each tensor element is + handled by exactly one thread, maintaining the same number of global loads + as in the blocked layout (i.e., each element is loaded only once). + }]; + + let constructor = "mlir::createTritonAMDGPUBypassLDSForDotOperand()"; + + let dependentDialects = []; } def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp new file mode 100644 index 000000000000..577f0bdc25a7 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp @@ -0,0 +1,176 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +//===----------------------------------------------------------------------===// +// AMDBypassLDSForDotOperandPass Overview +// +// The AMDBypassLDSForDotOperandPass implements a strategy to bypass using the +// Local Data Share (LDS) for one of the operands in an MFMA dot operation. +// +//===----------------------------------------------------------------------===// +// +// Standard Data Flow for MFMA Dot Operations: +// +// Typically, the data flow for operands in a dot operation involves three main +// steps: +// +// 1. Load Tensor from HBM to VGPRs: +// The tensor is initially loaded into the VGPRs using a blocked (coalesced) +// layout. +// +// 2. Write Tensor to Shared Memory (LDS): +// This step is used for data rearrangement across threads. +// +// 3. 
Read Tensor from Shared Memory: +// The tensor is read from shared memory using the dot layout, which is +// optimized for MFMA instructions. +// +//===----------------------------------------------------------------------===// +// +// Coalescing in Triton: +// +// Coalescing in Triton is managed by configuring parameters for the blocked +// layout during the Coalesce pass. There are two primary levels of +// coalescing to consider: +// +// 1. Maximizing Load Width: +// Achieving the widest possible loads ensures that elements are grouped into +// larger memory transactions by the VMEM unit. This reduces the number of +// instructions needed to load data, minimizing instruction queue size and +// reducing wait times. +// +// 2. Ensuring Consecutive Thread Access: +// When consecutive threads access sequential memory addresses, the TA unit +// in the VMEM can merge multiple requests into a single, larger transaction. +// This approach significantly boosts memory bandwidth utilization. +// +//===----------------------------------------------------------------------===// +// +// Optimizing the Data Flow: +// +// Under certain conditions, the dot layout of one of the operands allows direct +// loading from HBM to VGPRs in the MFMA dot layout, without losing level 1 +// coalescing efficiency (load width) or increasing the number of global loads +// due to shared data between threads. +// +// The required conditions are: +// +// 1. K-Major (K dimension is contiguous) Tensor Layout: +// The operand we want to bypass LDS for must be K-major (i.e., row-major for +// operand 0 or column-major for operand 1). This supports vectorized global +// load instructions, as MFMA instructions require each thread to hold B +// operand elements along the K dimension. +// +// 2. kWidth * bitwidth(dataType) == 128: +// Using the maximum kWidth for a specific data type ensures optimal global +// load vectorization (e.g., using global_load_dwordx4 instructions). +// +// 3. 
Single Warp per CTA Dimension: +// Either warpsPerCTA[nDim] == 1 for operand A bypass or warpsPerCTA[mDim] == +// 1 for operand B bypass. This guarantees that each tensor element is +// handled by exactly one thread, maintaining the same number of global loads +// as in the blocked layout (i.e., each element is loaded only once). +// +//===----------------------------------------------------------------------===// +// Current Limitations: +// These limitations are temporary and will be addressed in future updates: +// +// 1. Support is limited to bypassing LDS for operand 1 (e.g., weights in +// MoE-like kernels). Bypassing for operand 0 is not yet implemented. +// +// 2. LDS bypass is only supported for the fp16 data type due to the +// kWidth == 8 condition. Other data types will be supported in the future. +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace ttg = triton::gpu; + +// Find all tt.load operations in the module that are involved in computing the +// tensor operand that is being converted to the dot layout. 
+SmallVector getAllLoadOpsReachingOp(Operation *op, + ModuleOp &mod) { + SmallVector loadOpsVec; + + mod.walk([&](triton::LoadOp loadOp) { + SetVector forwardSlices; + getForwardSlice((Operation *)loadOp, &forwardSlices); + if (std::find(forwardSlices.begin(), forwardSlices.end(), op) != + forwardSlices.end()) { + loadOpsVec.push_back(loadOp); + } + }); + + return loadOpsVec; +} + +struct TritonAMDGPUBypassLDSForDotOperandPass + : public TritonAMDGPUBypassLDSForDotOperandBase< + TritonAMDGPUBypassLDSForDotOperandPass> { + + TritonAMDGPUBypassLDSForDotOperandPass() = default; + + void runOnOperation() override { + ModuleOp module = getOperation(); + auto convertOps = collectConvertOps(module); + + module.dump(); + + for (ttg::ConvertLayoutOp &convertOp : convertOps) { + auto loadInsts = getAllLoadOpsReachingOp(convertOp, module); + assert(!loadInsts.empty()); + + // Convert load instructions to dot layout. + for (auto loadInst : loadInsts) { + auto loadType = + dyn_cast(loadInst.getResult().getType()); + if (!loadType) + return; + + auto dstType = llvm::cast(convertOp.getType()); + auto dstDotOp = + llvm::cast(dstType.getEncoding()); + convertOpEncoding(dstDotOp, loadInst); + } + } + } + + SmallVector collectConvertOps(ModuleOp &module) { + SmallVector convertOps; + + module.walk([&](ttg::ConvertLayoutOp cvtOp) { + if (isEligibleConvertOp(cvtOp)) + convertOps.push_back(cvtOp); + }); + + return convertOps; + } + + // Check if the required conditions and current limitations from the above doc + // are met. 
+ bool isEligibleConvertOp(ttg::ConvertLayoutOp convertOp) { + auto srcType = dyn_cast(convertOp.getOperand().getType()); + auto dstType = dyn_cast(convertOp.getType()); + + if (!srcType || !dstType || srcType.getShape().size() != 2) + return false; + + auto srcBlocked = dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!srcBlocked || !dstDotOp) + return false; + + // srcBlocked.getOrder[0] == 0 is the requirement for opIdx 1 tensor to be K + // major (required condition 1) from the above doc). + auto mfmaLayout = dyn_cast(dstDotOp.getParent()); + return mfmaLayout && dstDotOp.getKWidth() == 8 && + mfmaLayout.getWarpsPerCTA()[0] == 1 && dstDotOp.getOpIdx() == 1 && + srcBlocked.getOrder()[0] == 0; + } +}; + +std::unique_ptr mlir::createTritonAMDGPUBypassLDSForDotOperand() { + return std::make_unique(); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 414e4a329fdb..8486cd241eed 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -5,7 +5,7 @@ add_triton_library(TritonAMDGPUTransforms ReorderInstructions.cpp StreamPipelineV2.cpp MfmaGroup.cpp - + AMDBypassLDSForDotOperand.cpp DEPENDS TritonAMDGPUTransformsIncGen TritonGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 027f06652f20..820e3d1ab176 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -307,18 +307,15 @@ void StreamPipeliner::assignMemoryLayouts() { cast(tensorTy.getElementType()).getPointeeType(); unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); - // Limit shared memory sharing to width >= 32 elements. 
- LDBG("Load " << *loadOp << " has width " << width); - if (width < 32) { - LDBG("Skip width<32 load " << *loadOp); - continue; - } - if (use->hasTrait()) { // Only use shared memory when feeding into a dot op. loadInfo.usedByDot = true; - loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + // Limit shared memory sharing to width >= 32 elements. + LDBG("Load " << *loadOp << " has width " << width); + if (width >= 32) { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } } else if (auto useOp = dyn_cast(use)) { // The use of this loadOp is another loadOp. If the use is not in the // loadToInfo already, it means that the use is not valid for pipelining diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a9f3a8ee2f60..0fb0416bf151 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -66,6 +66,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUOptimizeEpiloguePass); ADD_PASS_WRAPPER_0("add_canonicalize_pointers", mlir::createTritonAMDGPUCanonicalizePointersPass); + ADD_PASS_WRAPPER_0("add_tritongpu_bypass_lds_for_dot_operand", + mlir::createTritonAMDGPUBypassLDSForDotOperand); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_1("add_stream_pipelinev2",