diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir new file mode 100644 index 000000000000..520f44db933d --- /dev/null +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -0,0 +1,25 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp2 %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp %arg0 : tensor<64xf32, #blocked> + tt.return + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 899481a754e7..b092e32e6ce1 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -158,7 +158,15 @@ def make_llir(src, metadata, options): passes.convert.add_index_to_llvmir(pm) passes.ttgpuir.add_allocate_shared_memory(pm) - amd.passes.ttgpuir.add_to_llvmir(pm, options.arch) + ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows: + ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless + ## of the value of kernel arg `allow_flush_denorm`. + ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output + ## depends on the value of kernel arg `allow_flush_denorm`. + ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument. + ## For now it is used as a controller for developers only. + __HIP_FTZ = True + amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index d853f9c2101c..df5ad78494ab 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -25,7 +25,7 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch); } // namespace AMD std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); #define GEN_PASS_REGISTRATION diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index b14e3eb4d056..986c6763bbb3 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> { let summary = "Convert TritonGPU to LLVM"; - let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\")"; + let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)"; let dependentDialects = ["mlir::arith::ArithDialect", "mlir::math::MathDialect", @@ -30,6 +30,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod let options = [ Option<"arch", "arch", "std::string", /*default*/"\"\"", "gfx target device architecture, e.g., gfx942">, + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, ]; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index ac1d8c68d43c..af244ad1ed21 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -10,8 +10,10 @@ using namespace mlir; using namespace mlir::triton; +using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::ElementwiseOpConversionBase; using mlir::triton::gpu::getElementType; +using mlir::triton::gpu::getFunctionType; using mlir::triton::gpu::MultipleOperandsRange; typedef std::function(Location, ConversionPatternRewriter &, @@ -1213,23 +1215,66 @@ struct ExpOpConversionApprox ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { - // For non-FP32 input, call __nv_expf for higher-precision calculation + // For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation if (elemTy.getIntOrFloatBitWidth() != 32) return {}; const double log2e = 1.4426950408889634; Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); - return {rewriter.create(loc, f32_ty, prod, - adaptor.getAttributes().getValue())}; + // Here we use llvm.exp2.f32 instead of math::Exp2Op. The latter + // flushes denorms by default, but we want to preserve denorms by default + // for expOp. + StringRef funcName = "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return {rewriter.create(loc, funcOp, prod).getResult()}; } }; +struct Exp2OpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase< + mlir::math::Exp2Op, Exp2OpConversion>::ElementwiseOpConversionBase; + + explicit Exp2OpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, + PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(mlir::math::Exp2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + // On AMD backend, both intrinsics are lowered to v_exp_f32 instruction, + // which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides + // direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts + // instructions to handle denorms iff `allow_flush_denorm` is False. + StringRef funcName = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +private: + bool ftz; +}; + } // namespace namespace mlir::triton::AMD { void populateElementwiseOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit) { @@ -1257,11 +1302,15 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, targetInfo.getISAFamily(), benefit); - // ExpOpConversionApprox will try using ex2.approx if the input type is + // ExpOpConversionApprox will try using __ocml_exp2_f32 if the input type is // FP32. For other input types, ExpOpConversionApprox will return failure and - // ElementwiseOpConversion defined below will call - // __nv_expf for higher-precision calculation + // later pass will call __ocml_exp_f64 for higher-precision calculation patterns.add(typeConverter, axisInfoAnalysis, benefit); + // Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32 + // based on the ftz flag if the input type is FP32. For FP64 input, + // Exp2OpConversion will return failure and later pass will call + // __ocml_exp2_f64 for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); mlir::triton::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); mlir::triton::populateMinMaxFOpToLLVMPattern( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index e904f6794f02..67e5369b8650 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -17,7 +17,7 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); void populateElementwiseOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 6625b8a12027..8649911a7c2d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,8 +63,9 @@ class TritonLLVMConversionTarget : public ConversionTarget { struct ConvertTritonAMDGPUToLLVM : public triton::impl::ConvertTritonAMDGPUToLLVMBase< ConvertTritonAMDGPUToLLVM> { - explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch) { + explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) { this->arch = targetArch.str(); + this->ftz = ftz; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -174,7 +175,9 @@ struct ConvertTritonAMDGPUToLLVM typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, AMDBenefit); - populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns, AMDBenefit); + AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz, + axisInfoAnalysis, allocation, + targetInfo, AMDBenefit); AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, numWarps, axisInfoAnalysis, AMDBenefit); @@ -243,8 +246,8 @@ namespace mlir { namespace triton { std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch) { - return std::make_unique(targetArch); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) { + return std::make_unique(targetArch, ftz); } } // namespace triton diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 7c1f12c51760..ddc1feb2aa94 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -34,9 +34,10 @@ namespace py = pybind11; namespace { void init_triton_amd_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch) { - pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch)); - }); + m.def("add_to_llvmir", + [](mlir::PassManager &pm, const std::string &arch, bool ftz) { + pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); + }); m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(createConvertBuiltinFuncToLLVMPass()); });