-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[AMD] Handle denorms properly for exp2 and exp #3816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
817c7ef
33d9971
32dbc8a
d46b371
54cb9af
ad05f7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ | ||
| // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ | ||
|
|
||
|
|
||
| #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> | ||
| module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { | ||
| tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { | ||
| // LLVM_FTZ: llvm.amdgcn.exp2.f32 | ||
| // LLVM_NO_FTZ: llvm.exp2.f32 | ||
| %0 = math.exp2 %arg0 : tensor<64xf32, #blocked> | ||
| tt.return | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> | ||
| module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { | ||
| tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { | ||
| // LLVM_FTZ: llvm.exp2.f32 | ||
| // LLVM_NO_FTZ: llvm.exp2.f32 | ||
| %0 = math.exp %arg0 : tensor<64xf32, #blocked> | ||
| tt.return | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -158,7 +158,15 @@ def make_llir(src, metadata, options): | |
| passes.convert.add_index_to_llvmir(pm) | ||
|
|
||
| passes.ttgpuir.add_allocate_shared_memory(pm) | ||
| amd.passes.ttgpuir.add_to_llvmir(pm, options.arch) | ||
| ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows: | ||
| ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless | ||
| ## of the value of kernel arg `allow_flush_denorm`. | ||
| ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output | ||
| ## depends on the value of kernel arg `allow_flush_denorm`. | ||
| ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument. | ||
| ## For now it is used as a controller for developers only. | ||
| __HIP_FTZ = True | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having this be specifically control exp makes no sense. The name, by lack of exp, and similarity to the module level __CUDA_FTZ, would imply this is changing the global floating point environment denormal mode. There should be no special case modes for a specific operation
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The plan is to use this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I propose splitting it this way:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In this PR, exp2 and exp are lowered with llvm intrinsics directly for f32 inputs. For f64 inputs, I assume we cannot use llvm intrinsics, right? In follow up PR, we can fix other math functions by using llvm intrinsics directly.
Do you mean we should just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Correct
More precisely, this changes the default floating point mode.
Sounds like an Nvidia bug to me, unless this is a specific "fast" exp function. You shouldn't be touching any global module flag for this
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Triton math op semantics are sort of "defined" by the nvidia instructions there given the history. The overall goal is to figure out the fine details for various ops (a lot there) and document them properly and make sure we are consistent. So I'd expect that's a lenghty procedure and we might not get everything perfect in one go. I think these are good points to follow up on that aren't blocking. |
||
| amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) | ||
| passes.common.add_canonicalizer(pm) | ||
| passes.common.add_cse(pm) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,8 +10,10 @@ | |
| using namespace mlir; | ||
| using namespace mlir::triton; | ||
|
|
||
| using mlir::triton::gpu::appendOrGetExternFuncOp; | ||
| using mlir::triton::gpu::ElementwiseOpConversionBase; | ||
| using mlir::triton::gpu::getElementType; | ||
| using mlir::triton::gpu::getFunctionType; | ||
| using mlir::triton::gpu::MultipleOperandsRange; | ||
|
|
||
| typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &, | ||
|
|
@@ -1213,23 +1215,66 @@ struct ExpOpConversionApprox | |
| ConversionPatternRewriter &rewriter, | ||
| Type elemTy, MultipleOperandsRange operands, | ||
| Location loc) const { | ||
| // For non-FP32 input, call __nv_expf for higher-precision calculation | ||
| // For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation | ||
|
antiagainst marked this conversation as resolved.
|
||
| if (elemTy.getIntOrFloatBitWidth() != 32) | ||
| return {}; | ||
|
|
||
| const double log2e = 1.4426950408889634; | ||
| Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); | ||
|
|
||
| return {rewriter.create<math::Exp2Op>(loc, f32_ty, prod, | ||
| adaptor.getAttributes().getValue())}; | ||
| // Here we use llvm.exp2.f32 instead of math::Exp2Op. The latter | ||
| // flushes denorms by default, but we want to preserve denorms by default | ||
| // for expOp. | ||
| StringRef funcName = "llvm.exp2.f32"; | ||
| Type funcType = getFunctionType(elemTy, operands[0]); | ||
| LLVM::LLVMFuncOp funcOp = | ||
| appendOrGetExternFuncOp(rewriter, op, funcName, funcType); | ||
|
|
||
| return {rewriter.create<LLVM::CallOp>(loc, funcOp, prod).getResult()}; | ||
| } | ||
| }; | ||
|
|
||
| struct Exp2OpConversion | ||
| : ElementwiseOpConversionBase<mlir::math::Exp2Op, Exp2OpConversion> { | ||
| using ElementwiseOpConversionBase< | ||
| mlir::math::Exp2Op, Exp2OpConversion>::ElementwiseOpConversionBase; | ||
|
|
||
| explicit Exp2OpConversion(LLVMTypeConverter &typeConverter, | ||
| ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, | ||
| PatternBenefit benefit) | ||
| : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), | ||
| ftz(ftz) {} | ||
|
|
||
| SmallVector<Value> createDestOps(mlir::math::Exp2Op op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter, | ||
| Type elemTy, MultipleOperandsRange operands, | ||
| Location loc) const { | ||
| // For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation | ||
|
antiagainst marked this conversation as resolved.
|
||
| if (elemTy.getIntOrFloatBitWidth() != 32) | ||
| return {}; | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here worth a comment saying the former flushes denorm values and the later expands to LLVM instructions to handle denorm values + the former. |
||
| // On AMD backend, both intrinsics are lowered to v_exp_f32 instruction, | ||
| // which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides | ||
| // direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts | ||
| // instructions to handle denorms iff `allow_flush_denorm` is False. | ||
| StringRef funcName = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; | ||
| Type funcType = getFunctionType(elemTy, operands[0]); | ||
| LLVM::LLVMFuncOp funcOp = | ||
| appendOrGetExternFuncOp(rewriter, op, funcName, funcType); | ||
|
|
||
| return { | ||
| rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()}; | ||
| } | ||
|
|
||
| private: | ||
| bool ftz; | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| namespace mlir::triton::AMD { | ||
| void populateElementwiseOpToLLVMPatterns( | ||
| LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, | ||
| LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, | ||
| ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, | ||
| const TargetInfo &targetInfo, PatternBenefit benefit) { | ||
|
|
||
|
|
@@ -1257,11 +1302,15 @@ void populateElementwiseOpToLLVMPatterns( | |
| patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis, | ||
| targetInfo.getISAFamily(), benefit); | ||
|
|
||
| // ExpOpConversionApprox will try using ex2.approx if the input type is | ||
| // ExpOpConversionApprox will try using __ocml_exp2_f32 if the input type is | ||
| // FP32. For other input types, ExpOpConversionApprox will return failure and | ||
| // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call | ||
| // __nv_expf for higher-precision calculation | ||
| // later pass will call __ocml_exp_f64 for higher-precision calculation | ||
| patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit); | ||
| // Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32 | ||
| // based on the ftz flag if the input type is FP32. For FP64 input, | ||
| // Exp2OpConversion will return failure and later pass will call | ||
| // __ocml_exp2_f64 for higher-precision calculation | ||
| patterns.add<Exp2OpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit); | ||
|
antiagainst marked this conversation as resolved.
|
||
| mlir::triton::populateElementwiseOpToLLVMPatterns( | ||
| typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); | ||
| mlir::triton::populateMinMaxFOpToLLVMPattern( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.