diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h
index 4bd3ed0fe5b5..ebd32ad4422f 100644
--- a/bin/RegisterTritonDialects.h
+++ b/bin/RegisterTritonDialects.h
@@ -60,7 +60,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   // TritonAMDGPUTransforms passes
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
-  mlir::registerTritonAMDGPUBypassLDSForDotOperand();
   mlir::registerTritonAMDGPUReorderInstructions();
   mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
index 5e0c192afbc5..62f40417574f 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -205,8 +205,6 @@ enum class MMALoadType {
 };
 MMALoadType getMMALoadType(Operation *loadOp);
 
-// Convert \param op operands and results to layout \param encoding.
-void convertOpEncoding(Attribute encoding, Operation *op);
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index 3ef4fd1c4b1a..ba9440a40d61 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -31,7 +31,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_STREAM_PREFETCH",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
-    "TRITON_HIP_BYPASS_LDS_FOR_DOT",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
     "TRITON_OVERRIDE_ARCH",
diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
index cf0f952ce971..b3814329ae72 100644
--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -104,6 +104,55 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
                                        threadsPerWarp, CTALayout);
   }
 
+  static Type getNewType(Type type, Attribute encoding) {
+    RankedTensorType tensorType = cast<RankedTensorType>(type);
+    return RankedTensorType::get(tensorType.getShape(),
+                                 tensorType.getElementType(), encoding);
+  }
+
+  void coalesceOp(Attribute encoding, Operation *op) {
+    OpBuilder builder(op);
+    // Convert operands
+    // For load/store with tensor pointers, we don't have to change the
+    // operands' type, we do this by changing the outputs' type of
+    // `make_tensor_ptr`
+    SmallVector<Value, 4> newArgs;
+    for (auto operand : op->getOperands()) {
+      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+      if (tensorType &&
+          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
+        Type newType = getNewType(tensorType, encoding);
+        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), newType, operand));
+      } else {
+        newArgs.push_back(operand);
+      }
+    }
+
+    // Convert output types
+    SmallVector<Type, 4> newTypes;
+    for (auto t : op->getResultTypes()) {
+      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+    }
+
+    // Construct new op with the new encoding
+    Operation *newOp =
+        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
+                       newTypes, op->getAttrs());
+
+    // Cast the results back to the original layout
+    for (size_t i = 0; i < op->getNumResults(); i++) {
+      Value newResult = newOp->getResult(i);
+      if (newTypes[i] != op->getResultTypes()[i]) {
+        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), op->getResult(i).getType(), newResult);
+      }
+      op->getResult(i).replaceAllUsesWith(newResult);
+    }
+    op->erase();
+  }
+
   void runOnOperation() override {
     // Run axis info analysis
     ModuleOp moduleOp = getOperation();
@@ -138,7 +187,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     // 4. Convert the output of this new memory op back to L1
     // 5. Replace all the uses of the original memory op by the new one
     for (auto &kv : layoutMap) {
-      convertOpEncoding(kv.second, kv.first);
+      coalesceOp(kv.second, kv.first);
     }
   }
 };
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index ea93a918915d..9ea074d67181 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -1022,43 +1022,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
-bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
-  RankedTensorType targetType = convertOp.getType();
-  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
-  // If the target encoding is not DotOperandEncodingAttr, allow propagation.
-  if (!dotEnc) {
-    return true;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 0.
-  // This heuristic is applied to prevent moving the blocked->dot conversion of
-  // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
-  // so can increase register pressure and cause spilling in some cases.
-  if (dotEnc.getOpIdx() == 0) {
-    return false;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
-  // it's not intentionally placed above a load as we have to be a bit more
-  // careful with the heuristics for both correctness and performance.
-  // TODO: Fix this logic to avoid propagating conversions backward unless
-  // it reduces the total number of conversions.
-  assert(dotEnc.getOpIdx() == 1);
-  SetVector<Operation *> slice;
-  BackwardSliceOptions opt;
-  opt.omitBlockArguments = true;
-  opt.filter = [&](Operation *op) {
-    return op->getParentRegion() == convertOp->getParentRegion();
-  };
-  getBackwardSlice(convertOp.getOperation(), &slice, opt);
-
-  for (Operation *currOp : slice) {
-    if (isa<triton::LoadOp>(currOp)) {
-      return false;
-    }
-  }
-  // Allow propagation if no LoadOp is found.
-  return true;
-}
-
 void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
@@ -1077,11 +1040,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
 
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
-
   Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
@@ -1120,10 +1083,11 @@ void LayoutRematerialization::backwardRematerialization(
 // of the convert.
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristics to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
 
   auto isExtOrBroadcastOp = [](Operation *op) {
     if (isa<
[...]
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
[...]
-static Type getNewType(Type type, Attribute encoding) {
-  RankedTensorType tensorType = cast<RankedTensorType>(type);
-  return RankedTensorType::get(tensorType.getShape(),
-                               tensorType.getElementType(), encoding);
-}
-
-void convertOpEncoding(Attribute encoding, Operation *op) {
-  OpBuilder builder(op);
-  // Convert operands
-  // For load/store with tensor pointers, we don't have to change the
-  // operands' type, we do this by changing the outputs' type of
-  // `make_tensor_ptr`
-  SmallVector<Value, 4> newArgs;
-  for (auto operand : op->getOperands()) {
-    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-    if (tensorType &&
-        !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
-      Type newType = getNewType(tensorType, encoding);
-      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), newType, operand));
-    } else {
-      newArgs.push_back(operand);
-    }
-  }
-
-  // Convert output types
-  SmallVector<Type, 4> newTypes;
-  for (auto t : op->getResultTypes()) {
-    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
-    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
-  }
-
-  // Construct new op with the new encoding
-  Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
-                                    newArgs, newTypes, op->getAttrs());
-
-  // Cast the results back to the original layout
-  for (size_t i = 0; i < op->getNumResults(); i++) {
-    Value newResult = newOp->getResult(i);
-    if (newTypes[i] != op->getResultTypes()[i]) {
-      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), op->getResult(i).getType(), newResult);
-    }
-    op->getResult(i).replaceAllUsesWith(newResult);
-  }
-  op->erase();
-}
-
 namespace {
 
 /// Detect dead arguments in scf.for op by assuming all the values are dead and
diff --git a/test/TritonGPU/amd/bypass-lds.mlir b/test/TritonGPU/amd/bypass-lds.mlir
deleted file mode 100644
index 549a44941413..000000000000
--- a/test/TritonGPU/amd/bypass-lds.mlir
+++ /dev/null
@@ -1,154 +0,0 @@
-// RUN: triton-opt %s -split-input-file -tritonamdgpu-bypass-lds-for-dot-operand -tritongpu-remove-layout-conversions | FileCheck %s
-
-// For Bypass LDS optimization to be efficient we need collaboration of 2 passes:
-// 1) Bypass LDS pass: To convert load from blocked->dot layout.
-// 2) Remove layout conversion pass: To remove blocked->dot layout by changing layout of all ops that form tensor of pointers to dot layout.
-// Check that all of the optimizations were done properly to create efficient IR.
-
-// CHECK-LABEL: bypass_lds
-// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<{{.*}}, #blocked2> -> tensor<{{.*}}, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>>
-// CHECK: %[[DOT_LOAD:.+]] = tt.load %{{.*}} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-// CHECK: tt.dot %{{.*}}, %[[DOT_LOAD:.+]], %{{.*}} : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
-#sliced_blocked1 = #ttg.slice<{parent=#blocked1, dim=0}>
-#sliced_blocked2 = #ttg.slice<{parent=#blocked2, dim=0}>
-#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @bypass_lds(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
-    %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2>
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c64_i32 = arith.constant 64 : i32
-    %c63_i32 = arith.constant 63 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma>
-    %a_ptr_splat = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
-    %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1>
-    %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1>
-    %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
-    %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-    %b_ptr_splat = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked2>
-    %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2>
-    %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2>
-    %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2>
-    %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-    %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>) : i32 {
-      %74 = tt.load %arg12 : tensor<256x64x!tt.ptr<f16>, #blocked1>
-      %75 = tt.load %arg13 : tensor<64x256x!tt.ptr<f16>, #blocked2>
-      %76 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>>
-      %77 = ttg.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>>
-      %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<256x256xf32, #mfma>
-      %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-      %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-      scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>
-    }
-    %addr_res = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked3>
-    %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma>
-    %72 = ttg.convert_layout %addr_res : tensor<256x256x!tt.ptr<f16>, #blocked3> -> tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.store %72, %57 : tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.return
-  }
-}
-
-// -----
-
-// Check that bypass LDS optimization is not done because warpsPerCTA condition is not satisfied.
-
-// CHECK-LABEL: no_bypass_lds_warps_per_cta
-// CHECK-NOT: tt.load %{{.*}} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
-#sliced_blocked1 = #ttg.slice<{parent=#blocked1, dim=0}>
-#sliced_blocked2 = #ttg.slice<{parent=#blocked2, dim=0}>
-#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @no_bypass_lds_warps_per_cta(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
-    %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2>
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c64_i32 = arith.constant 64 : i32
-    %c63_i32 = arith.constant 63 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma>
-    %a_ptr_splat = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
-    %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1>
-    %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1>
-    %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
-    %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-    %b_ptr_splat = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked2>
-    %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2>
-    %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2>
-    %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2>
-    %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-    %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>) : i32 {
-      %74 = tt.load %arg12 : tensor<256x64x!tt.ptr<f16>, #blocked1>
-      %75 = tt.load %arg13 : tensor<64x256x!tt.ptr<f16>, #blocked2>
-      %76 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>>
-      %77 = ttg.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>>
-      %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<256x256xf32, #mfma>
-      %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-      %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-      scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>
-    }
-    %addr_res = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked3>
-    %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma>
-    %72 = ttg.convert_layout %addr_res : tensor<256x256x!tt.ptr<f16>, #blocked3> -> tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.store %72, %57 : tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.return
-  }
-}
-
-// -----
-
-// Check that bypass LDS optimization is not done because kWidth condition is not satisfied.
-
-// CHECK-LABEL: no_bypass_lds_kWidth
-// CHECK-NOT: tt.load %{{.*}} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
-#sliced_blocked1 = #ttg.slice<{parent=#blocked1, dim=0}>
-#sliced_blocked2 = #ttg.slice<{parent=#blocked2, dim=0}>
-#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @no_bypass_lds_kWidth(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
-    %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked2>
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c64_i32 = arith.constant 64 : i32
-    %c63_i32 = arith.constant 63 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mfma>
-    %a_ptr_splat = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
-    %a_tmp0 = tt.make_range {end = 64: i32, start = 0: i32} : tensor<64xi32, #sliced_blocked1>
-    %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<64xi32, #sliced_blocked1> -> tensor<1x64xi32, #blocked1>
-    %a_offs = tt.broadcast %a_tmp1 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
-    %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-    %b_ptr_splat = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked2>
-    %b_tmp0 = tt.make_range {end = 256: i32, start = 0: i32} : tensor<256xi32, #sliced_blocked2>
-    %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<256xi32, #sliced_blocked2> -> tensor<1x256xi32, #blocked2>
-    %b_offs = tt.broadcast %b_tmp1 : tensor<1x256xi32, #blocked2> -> tensor<64x256xi32, #blocked2>
-    %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-    %56:3 = scf.for %arg10 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %a_ptr_init, %arg13 = %b_ptr_init) -> (tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>) : i32 {
-      %74 = tt.load %arg12 : tensor<256x64x!tt.ptr<f16>, #blocked1>
-      %75 = tt.load %arg13 : tensor<64x256x!tt.ptr<f16>, #blocked2>
-      %76 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked1> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-      %77 = ttg.convert_layout %75 : tensor<64x256xf16, #blocked2> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-      %78 = tt.dot %76, %77, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x256xf32, #mfma>
-      %79 = tt.addptr %arg12, %cst : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
-      %80 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked2>, tensor<64x256xi32, #blocked2>
-      scf.yield %78, %79, %80 : tensor<256x256xf32, #mfma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked2>
-    }
-    %addr_res = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked3>
-    %57 = arith.truncf %56#0 : tensor<256x256xf32, #mfma> to tensor<256x256xf16, #mfma>
-    %72 = ttg.convert_layout %addr_res : tensor<256x256x!tt.ptr<f16>, #blocked3> -> tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.store %72, %57 : tensor<256x256x!tt.ptr<f16>, #mfma>
-    tt.return
-  }
-}
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index f458b681f549..54da2b869b7d 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -288,23 +288,6 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
 }
 }
 
-// CHECK-LABEL: @check_dot_op_idx1_propagation
-module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
-tt.func @check_dot_op_idx1_propagation(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
-  // CHECK-NOT: ttg.convert_layout {{.*}} : {{.*}} -> tensor<{{.*}}, #ttg.dot_op<{opIdx = 1, parent = {{.*}}}>>
-  // CHECK: tt.load {{.+}} : tensor<{{.*}}, #ttg.dot_op<{opIdx = 1, parent = {{.*}}}>>
-  %cst_1 = arith.constant dense<1> : tensor<64x64xi32, #blocked1>
-  %splat_load_addr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked1>
-  %splat_store_addr = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked1>
-  %1 = tt.addptr %splat_load_addr, %cst_1 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
-  %2 = ttg.convert_layout %1 : tensor<64x64x!tt.ptr<f32>, #blocked1> -> tensor<64x64x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>
-  %3 = tt.load %2 : tensor<64x64x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>
-  %4 = ttg.convert_layout %3 : tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x64xf32, #blocked1>
-  tt.store %splat_store_addr, %4 : tensor<64x64x!tt.ptr<f32>, #blocked1>
-  tt.return
-}
-}
-
 // CHECK-LABEL: loop
"ttg.num-ctas" = 1 : i32} { tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 0f94cf5846d9..0dae76c0d634 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -225,10 +225,6 @@ def make_ttgir(mod, metadata, options): stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1" - bypass_lds = os.environ.get("TRITON_HIP_BYPASS_LDS_FOR_DOT", "0") == "1" - - if bypass_lds: - amd.passes.ttgpuir.add_bypass_lds_for_dot_operand(pm) # The `local-prefetch` scheduling variant requires turning on buffer ops. if options.instruction_sched_variant == "local-prefetch": diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 44751f4fc0b2..c375d2a386b2 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -23,7 +23,6 @@ std::unique_ptr createTritonAMDGPUReorderInstructionsPass(); std::unique_ptr createTritonAMDGPUVerifier(); std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); -std::unique_ptr createTritonAMDGPUBypassLDSForDotOperand(); std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 3cd14bec0b6c..f026d1d59521 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -99,34 +99,7 @@ def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers" let constructor = "mlir::createTritonAMDGPUCanonicalizePointersPass()"; let dependentDialects = []; -} - -def TritonAMDGPUBypassLDSForDotOperand: Pass<"tritonamdgpu-bypass-lds-for-dot-operand", "mlir::ModuleOp"> { - let summary = "Bypass moving data trough LDS for dot operand when possible."; - let description = [{ - Under certain conditions, the dot layout of one of the operands allows direct - loading from HBM to VGPRs in the MFMA dot layout, without losing of vectorization of global loads - or increasing the number of global loads due to shared data between threads. - The required conditions are: - 1. K-Major Tensor Layout: - The operand we want to bypass LDS for must be K-major (i.e., row-major for - operand 0 or column-major for operand 1). This supports vectorized global - load instructions, as MFMA instructions require each thread to hold B - operand elements along the K dimension. - 2. kWidth * sizeof(dataType) == 128: - Using the maximum kWidth for a specific data type ensures optimal global - load vectorization (e.g., using global_load_dwordx4 instructions). - 3. Single Warp per CTA Dimension: - Either warpsPerCTA[ndim] == 1 for operand A bypass or warpsPerCTA[mDim] == - 1 for operand B bypass. This guarantees that each tensor element is - handled by exactly one thread, maintaining the same number of global loads - as in the blocked layout (i.e., each element is loaded only once). 
-  }];
-
-  let constructor = "mlir::createTritonAMDGPUBypassLDSForDotOperand()";
-
-  let dependentDialects = [];
 }
 
 def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> {
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp
deleted file mode 100644
index c586b10ff975..000000000000
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp
+++ /dev/null
@@ -1,172 +0,0 @@
-#include "triton/Analysis/Utility.h"
-#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#define GEN_PASS_CLASSES
-#include "TritonAMDGPUTransforms/Passes.h"
-
-//===----------------------------------------------------------------------===//
-// AMDBypassLDSForDotOperandPass Overview
-//
-// The AMDBypassLDSForDotOperandPass implements a strategy to bypass using the
-// Local Data Share (LDS) for one of the operands in an MFMA dot operation.
-//
-//===----------------------------------------------------------------------===//
-//
-// Standard Data Flow for MFMA Dot Operations:
-//
-// Typically, the data flow for operands in a dot operation involves three main
-// steps:
-//
-// 1. Load Tensor from HBM to VGPRs:
-//    The tensor is initially loaded into the VGPRs using a blocked (coalesced)
-//    layout.
-//
-// 2. Write Tensor to Shared Memory (LDS):
-//    This step is used for data rearrangement across threads.
-//
-// 3. Read Tensor from Shared Memory:
-//    The tensor is read from shared memory using the dot layout, which is
-//    optimized for MFMA instructions.
-//
-//===----------------------------------------------------------------------===//
-//
-// Coalescing in Triton:
-//
-// Coalescing in Triton is managed by configuring parameters for the blocked
-// layout during the Coalesce pass. There are two primary levels of
-// coalescing to consider:
-//
-// 1. Maximizing Load Width:
-//    Achieving the widest possible loads ensures that elements are grouped into
-//    larger memory transactions by the VMEM unit. This reduces the number of
-//    instructions needed to load data, minimizing instruction queue size and
-//    reducing wait times.
-//
-// 2. Ensuring Consecutive Thread Access:
-//    When consecutive threads access sequential memory addresses we expect
-//    fewer requests per cache line. This approach can significantly affect
-//    memory bandwidth utilization.
-//
-//===----------------------------------------------------------------------===//
-//
-// Optimizing the Data Flow:
-//
-// Under certain conditions, the dot layout of one of the operands allows direct
-// loading from HBM to VGPRs in the MFMA dot layout, without losing level 1
-// coalescing efficiency or increasing the number of global loads due to shared
-// data between threads.
-//
-// The required conditions are:
-//
-// 1. K-Contig (K dimension is continuous) Tensor Layout:
-//    The operand we want to bypass LDS for must be K-contig (i.e., row-major
-//    for operand 0 or column-major for operand 1). This supports vectorized
-//    global load instructions, as MFMA instructions require each thread to hold
-//    B operand elements along the K dimension.
-//
-// 2. kWidth * sizeof(dataType) == 128:
-//    Using the maximum kWidth for a specific data type ensures optimal global
-//    load vectorization (e.g., using global_load_dwordx4 instructions).
-//
-// 3. Single Warp per CTA Dimension:
-//    Either warpsPerCTA[ndim] == 1 for operand A bypass or warpsPerCTA[mDim] ==
-//    1 for operand B bypass. This guarantees that each tensor element is
-//    handled by exactly one thread, maintaining the same number of global loads
-//    as in the blocked layout (i.e., each element is loaded only once).
-//
-//===----------------------------------------------------------------------===//
-// Current Limitations:
-// These limitations are temporary and will be addressed in future updates:
-//
-// 1. Support is limited to bypassing LDS for operand 1 (e.g., weights in
-//    MoE-like kernels). Bypassing for operand 0 is not yet implemented.
-//===----------------------------------------------------------------------===//
-
-using namespace mlir;
-namespace ttg = triton::gpu;
-
-// Find all tt.load instructions that are involved in computation of a tensor
-// for operand that is getting converted to dot layout.
-SmallVector<triton::LoadOp> getAllLoadOpsReachingOp(Operation *op) {
-  SmallVector<triton::LoadOp> loadOpsVec;
-  SetVector<Operation *> backwardSlice;
-  BackwardSliceOptions opt;
-  opt.omitBlockArguments = true;
-  getBackwardSlice(op, &backwardSlice, opt);
-
-  for (auto op : backwardSlice) {
-    if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
-      loadOpsVec.push_back(loadOp);
-    }
-  }
-
-  return loadOpsVec;
-}
-
-struct TritonAMDGPUBypassLDSForDotOperandPass
-    : public TritonAMDGPUBypassLDSForDotOperandBase<
-          TritonAMDGPUBypassLDSForDotOperandPass> {
-
-  TritonAMDGPUBypassLDSForDotOperandPass() = default;
-
-  void runOnOperation() override {
-    ModuleOp module = getOperation();
-    auto convertOps = collectConvertOps(module);
-
-    for (ttg::ConvertLayoutOp &convertOp : convertOps) {
-      auto loadInsts = getAllLoadOpsReachingOp(convertOp);
-      assert(!loadInsts.empty());
-
-      // Convert load instructions to dot layout.
-      for (auto loadInst : loadInsts) {
-        auto loadType =
-            dyn_cast<RankedTensorType>(loadInst.getResult().getType());
-        if (!loadType)
-          continue;
-
-        auto dstType = llvm::cast<RankedTensorType>(convertOp.getType());
-        auto dstDotOp =
-            llvm::cast<ttg::DotOperandEncodingAttr>(dstType.getEncoding());
-        convertOpEncoding(dstDotOp, loadInst);
-      }
-    }
-  }
-
-  SmallVector<ttg::ConvertLayoutOp> collectConvertOps(ModuleOp &module) {
-    SmallVector<ttg::ConvertLayoutOp> convertOps;
-
-    module.walk([&](ttg::ConvertLayoutOp cvtOp) {
-      if (isEligibleConvertOp(cvtOp))
-        convertOps.push_back(cvtOp);
-    });
-
-    return convertOps;
-  }
-
-  // Check if the required conditions and current limitations from the above doc
-  // are met.
-  bool isEligibleConvertOp(ttg::ConvertLayoutOp convertOp) {
-    auto srcType = dyn_cast<RankedTensorType>(convertOp.getOperand().getType());
-    auto dstType = dyn_cast<RankedTensorType>(convertOp.getType());
-
-    if (!srcType || !dstType)
-      return false;
-
-    auto srcBlocked = dyn_cast<ttg::BlockedEncodingAttr>(srcType.getEncoding());
-    auto dstDotOp =
-        dyn_cast<ttg::DotOperandEncodingAttr>(dstType.getEncoding());
-    if (!srcBlocked || !dstDotOp)
-      return false;
-
-    // srcBlocked.getOrder[0] == 0 is the requirement for opIdx 1 tensor to be K
-    // contig (required condition 1 from the above doc).
-    auto mfmaLayout = dyn_cast<ttg::AMDMfmaEncodingAttr>(dstDotOp.getParent());
-    return mfmaLayout &&
-           (dstDotOp.getKWidth() == 8 || dstDotOp.getKWidth() == 16) &&
-           mfmaLayout.getWarpsPerCTA()[0] == 1 && dstDotOp.getOpIdx() == 1 &&
-           srcBlocked.getOrder()[0] == 0;
-  }
-};
-
-std::unique_ptr<Pass> mlir::createTritonAMDGPUBypassLDSForDotOperand() {
-  return std::make_unique<TritonAMDGPUBypassLDSForDotOperandPass>();
-}
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
index 360c1781d8c7..d6fed403856c 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
@@ -7,7 +7,7 @@ add_triton_library(TritonAMDGPUTransforms
   ReorderInstructions.cpp
   StreamPipeline.cpp
   MfmaGroup.cpp
-  AMDBypassLDSForDotOperand.cpp
+
   DEPENDS
   TritonAMDGPUIR
   TritonAMDGPUTransformsIncGen
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index 3b9e5be905ea..9e27e06b97e9 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -71,8 +71,6 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_1("add_convert_to_buffer_ops",
                      mlir::createTritonAMDGPUConvertToBufferOpsPass,
                      const std::string &);
-  ADD_PASS_WRAPPER_0("add_bypass_lds_for_dot_operand",
-                     mlir::createTritonAMDGPUBypassLDSForDotOperand);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      mlir::createTritonAMDGPUReorderInstructionsPass);
   ADD_PASS_WRAPPER_0("add_block_pingpong",