diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 1e6e38749f4c..a129cb1947c6 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -68,6 +68,9 @@ SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { return convertType(getShapePerCTA(srcTy)); } + if (isMfmaToDotShortcut(srcTy, dstTy)) + return {}; + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { if (mlir::isa(dstLayout)) { @@ -111,11 +114,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - if (mlir::isa(srcLayout) && - mlir::dyn_cast(srcLayout).getIsTransposed() && - mlir::isa(dstLayout)) - if (isMfmaToDotShortcut(srcTy, dstTy)) - return {}; + assert(!isMfmaToDotShortcut(srcTy, dstTy)); auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); unsigned srcContigPerThread = diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 1851d9b6feb6..689e83b5acd9 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -581,8 +581,10 @@ bool supportMMA(Value value, int version) { bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); - auto mfmaLayout = cast(srcLayout); - auto dotOperandLayout = cast(dstLayout); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is // improved. In addition, we can enable this shortcut for regular MFMA // layout when opIdx == 1. 
diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir new file mode 100644 index 000000000000..520f44db933d --- /dev/null +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -0,0 +1,25 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp2 %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @test_exp(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp %arg0 : tensor<64xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir new file mode 100644 index 000000000000..83c9e535d8c0 --- /dev/null +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s + +#mfma 
= #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: shortcut_mfma16 + tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: no_shortcut_mfma16 + tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK: store + // CHECK: load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index d690d11a2e38..bdf2f863b3d6 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -158,7 +158,15 @@ def make_llir(src, metadata, options): passes.convert.add_index_to_llvmir(pm) passes.ttgpuir.add_allocate_shared_memory(pm) - amd.passes.ttgpuir.add_to_llvmir(pm, options.arch) + ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows: + ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless + ## of the value of kernel arg `allow_flush_denorm`. + ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output + ## depends on the value of kernel arg `allow_flush_denorm`. + ## 3. 
__HIP_FTZ defaults to 1 and is not exposed as a kernel argument. + ## For now it is used as a controller for developers only. + __HIP_FTZ = True + amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index d853f9c2101c..df5ad78494ab 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -25,7 +25,7 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch); } // namespace AMD std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); #define GEN_PASS_REGISTRATION diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index b14e3eb4d056..986c6763bbb3 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> { let summary = "Convert TritonGPU to LLVM"; - let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\")"; + let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)"; let dependentDialects = ["mlir::arith::ArithDialect", "mlir::math::MathDialect", @@ -30,6 +30,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod let options = [ Option<"arch", "arch", "std::string", /*default*/"\"\"", "gfx target device architecture, e.g., gfx942">, + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, ]; } diff --git 
a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index edebbbe12c07..953b01dab08a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -168,7 +168,5 @@ void populateConvertLayoutOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, - patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index dff317729815..dd082d25d784 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -10,8 +10,10 @@ using namespace mlir; using namespace mlir::triton; +using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::ElementwiseOpConversionBase; using mlir::triton::gpu::getElementType; +using mlir::triton::gpu::getFunctionType; using mlir::triton::gpu::MultipleOperandsRange; typedef std::function(Location, ConversionPatternRewriter &, @@ -1213,23 +1215,66 @@ struct ExpOpConversionApprox ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { - // For non-FP32 input, call __nv_expf for higher-precision calculation + // For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation if (elemTy.getIntOrFloatBitWidth() != 32) return {}; const double log2e = 1.4426950408889634; Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); - return {rewriter.create(loc, f32_ty, prod, - adaptor.getAttributes().getValue())}; + // Here we use llvm.exp2.f32 instead of math::Exp2Op. 
The latter + // flushes denorms by default, but we want to preserve denorms by default + // for expOp. + StringRef funcName = "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return {rewriter.create(loc, funcOp, prod).getResult()}; } }; +struct Exp2OpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase< + mlir::math::Exp2Op, Exp2OpConversion>::ElementwiseOpConversionBase; + + explicit Exp2OpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, + PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(mlir::math::Exp2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + // On AMD backend, both intrinsics are lowered to v_exp_f32 instruction, + // which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides + // direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts + // instructions to handle denorms iff `allow_flush_denorm` is False. + StringRef funcName = ftz ? 
"llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +private: + bool ftz; +}; + } // namespace namespace mlir::triton::AMD { void populateElementwiseOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit) { @@ -1257,11 +1302,15 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, targetInfo.getISAFamily(), benefit); - // ExpOpConversionApprox will try using ex2.approx if the input type is + // ExpOpConversionApprox will try using llvm.exp2.f32 if the input type is // FP32. For other input types, ExpOpConversionApprox will return failure and - // ElementwiseOpConversion defined below will call - // __nv_expf for higher-precision calculation + // a later pass will call __ocml_exp_f64 for higher-precision calculation patterns.add(typeConverter, axisInfoAnalysis, benefit); + // Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32 + // based on the ftz flag if the input type is FP32. 
For FP64 input, + // Exp2OpConversion will return failure and later pass will call + // __ocml_exp2_f64 for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); mlir::triton::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); mlir::triton::populateMinMaxFOpToLLVMPattern( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index e904f6794f02..67e5369b8650 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -17,7 +17,7 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); void populateElementwiseOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index c61dd5b815d2..8649911a7c2d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,8 +63,9 @@ class TritonLLVMConversionTarget : public ConversionTarget { struct ConvertTritonAMDGPUToLLVM : public triton::impl::ConvertTritonAMDGPUToLLVMBase< ConvertTritonAMDGPUToLLVM> { - explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch) { + explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) { this->arch = targetArch.str(); + this->ftz = ftz; } void getDependentDialects(DialectRegistry ®istry) const override { @@ 
-145,48 +146,60 @@ struct ConvertTritonAMDGPUToLLVM OpBuilder::InsertPoint indexInsertPoint; RewritePatternSet patterns(context); - int benefit = patternBenefitPrioritizeOverLLVMConversions; - auto populatePatterns1 = [&](auto populateFunc) { + int commonBenefit = patternBenefitPrioritizeOverLLVMConversions; + // Make benefit for AMD specific patterns higher so they apply before common + // patterns + int AMDBenefit = commonBenefit + 1; + auto populatePatterns1 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, allocation, benefit); }; - auto populatePatterns5 = [&](auto populateFunc) { + auto populatePatterns5 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, benefit); }; - auto populatePatterns6 = [&](auto populateFunc) { + auto populatePatterns6 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, allocation, targetInfo, benefit); }; - auto populatePatterns7 = [&](auto populateFunc) { + auto populatePatterns7 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, targetInfo, benefit); }; AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, patterns, numWarps, - axisInfoAnalysis, benefit); + axisInfoAnalysis, AMDBenefit); + mlir::triton::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, benefit); - populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns); + axisInfoAnalysis, AMDBenefit); + AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz, + axisInfoAnalysis, allocation, + targetInfo, AMDBenefit); AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, - numWarps, axisInfoAnalysis, benefit); - populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns); - 
populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns); - populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns); - populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns); + numWarps, axisInfoAnalysis, + AMDBenefit); + populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns, + commonBenefit); + populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, + commonBenefit); mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, benefit); + patterns, commonBenefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, - patterns, benefit); + patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + targetInfo, commonBenefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, - benefit); + commonBenefit); mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); + targetInfo, commonBenefit); + AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. 
@@ -200,7 +213,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + targetInfo, commonBenefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } @@ -233,8 +246,8 @@ namespace mlir { namespace triton { std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch) { - return std::make_unique(targetArch); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) { + return std::make_unique(targetArch, ftz); } } // namespace triton diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 7c1f12c51760..ddc1feb2aa94 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -34,9 +34,10 @@ namespace py = pybind11; namespace { void init_triton_amd_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch) { - pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch)); - }); + m.def("add_to_llvmir", + [](mlir::PassManager &pm, const std::string &arch, bool ftz) { + pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); + }); m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(createConvertBuiltinFuncToLLVMPass()); });