Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/Conversion/amd/math-denorm-handling.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ


#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.amdgcn.exp2.f32
// LLVM_NO_FTZ: llvm.exp2.f32
%0 = math.exp2 %arg0 : tensor<64xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.exp2.f32
// LLVM_NO_FTZ: llvm.exp2.f32
%0 = math.exp %arg0 : tensor<64xf32, #blocked>
tt.return
}
}
10 changes: 9 additions & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,15 @@ def make_llir(src, metadata, options):
passes.convert.add_index_to_llvmir(pm)

passes.ttgpuir.add_allocate_shared_memory(pm)
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
__HIP_FTZ = True
Comment thread
antiagainst marked this conversation as resolved.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having this be specifically control exp makes no sense. The name, by lack of exp, and similarity to the module level __CUDA_FTZ, would imply this is changing the global floating point environment denormal mode. There should be no special case modes for a specific operation

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is to use this __HIP_FTZ to control denorm handling behavior for all related math functions. This PR serves as a first step and selects exp2 because it affects performance of FA kernels.
__CUDA_FTZ is used only in the libdevice.bc. We don't have the DAZ_OPT flag in AMD ocml.bc anymore, so I exposed this flag here.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose splitting it this way:

  1. Fix this to start directly emitting the llvm intrinsic directly
  2. Holistically change the mode to use FTZ, then you don't need to special case anything

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Fix this to start directly emitting the llvm intrinsic directly

In this PR, exp2 and exp are lowered with llvm intrinsics directly for f32 inputs. For f64 inputs, I assume we cannot use llvm intrinsics, right? In follow up PR, we can fix other math functions by using llvm intrinsics directly.

  1. Holistically change the mode to use FTZ, then you don't need to special case anything

Do you mean we should just use denorm-fp-math-f32 to control whether denorms are flushed for exp2, as with all other valu operations? Unfortunately, what we are trying to achieve here is to only flush denorms for exp2, which is the case on the nvidia side.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, exp2 and exp are lowered with llvm intrinsics directly for f32 inputs. For f64 inputs, I assume we cannot use llvm intrinsics, right? In follow up PR, we can fix other math functions by using llvm intrinsics directly.

Correct

Do you mean we should just use denorm-fp-math-f32 to control whether denorms are flushed for exp2, as with all other valu operations?

More precisely, this changes the default floating point mode.

Unfortunately, what we are trying to achieve here is to only flush denorms for exp2, which is the case on the nvidia side.

Sounds like an Nvidia bug to me, unless this is a specific "fast" exp function. You shouldn't be touching any global module flag for this

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Triton math op semantics are sort of "defined" by the nvidia instructions there given the history. The overall goal is to figure out the fine details for various ops (a lot there) and document them properly and make sure we are consistent. So I'd expect that's a lenghty procedure and we might not get everything perfect in one go. I think these are good points to follow up on that aren't blocking.

amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)

Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();

#define GEN_PASS_REGISTRATION
Expand Down
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers

def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\")";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::math::MathDialect",
Expand All @@ -30,6 +30,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
let options = [
Option<"arch", "arch", "std::string", /*default*/"\"\"",
"gfx target device architecture, e.g., gfx942">,
Option<"ftz", "ftz", "bool", /*default*/"true",
"flush denorms for math functions">,
];
}

Expand Down
63 changes: 56 additions & 7 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
using namespace mlir;
using namespace mlir::triton;

using mlir::triton::gpu::appendOrGetExternFuncOp;
using mlir::triton::gpu::ElementwiseOpConversionBase;
using mlir::triton::gpu::getElementType;
using mlir::triton::gpu::getFunctionType;
using mlir::triton::gpu::MultipleOperandsRange;

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
Expand Down Expand Up @@ -1213,23 +1215,66 @@ struct ExpOpConversionApprox
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// For non-FP32 input, call __nv_expf for higher-precision calculation
// For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation
Comment thread
antiagainst marked this conversation as resolved.
if (elemTy.getIntOrFloatBitWidth() != 32)
return {};

const double log2e = 1.4426950408889634;
Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));

return {rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
adaptor.getAttributes().getValue())};
// Here we use llvm.exp2.f32 instead of math::Exp2Op. The latter
// flushes denorms by default, but we want to preserve denorms by default
// for expOp.
StringRef funcName = "llvm.exp2.f32";
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {rewriter.create<LLVM::CallOp>(loc, funcOp, prod).getResult()};
}
};

struct Exp2OpConversion
: ElementwiseOpConversionBase<mlir::math::Exp2Op, Exp2OpConversion> {
using ElementwiseOpConversionBase<
mlir::math::Exp2Op, Exp2OpConversion>::ElementwiseOpConversionBase;

explicit Exp2OpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz,
PatternBenefit benefit)
: ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
ftz(ftz) {}

SmallVector<Value> createDestOps(mlir::math::Exp2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation
Comment thread
antiagainst marked this conversation as resolved.
if (elemTy.getIntOrFloatBitWidth() != 32)
return {};

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here worth a comment saying the former flushes denorm values and the later expands to LLVM instructions to handle denorm values + the former.

// On AMD backend, both intrinsics are lowered to v_exp_f32 instruction,
// which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides
// direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts
// instructions to handle denorms iff `allow_flush_denorm` is False.
StringRef funcName = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32";
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
}

private:
bool ftz;
};

} // namespace

namespace mlir::triton::AMD {
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
const TargetInfo &targetInfo, PatternBenefit benefit) {

Expand Down Expand Up @@ -1257,11 +1302,15 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
targetInfo.getISAFamily(), benefit);

// ExpOpConversionApprox will try using ex2.approx if the input type is
// ExpOpConversionApprox will try using __ocml_exp2_f32 if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
// later pass will call __ocml_exp_f64 for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
// Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32
// based on the ftz flag if the input type is FP32. For FP64 input,
// Exp2OpConversion will return failure and later pass will call
// __ocml_exp2_f64 for higher-precision calculation
patterns.add<Exp2OpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
Comment thread
antiagainst marked this conversation as resolved.
mlir::triton::populateElementwiseOpToLLVMPatterns(
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
mlir::triton::populateMinMaxFOpToLLVMPattern(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit);
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
const TargetInfo &targetInfo, PatternBenefit benefit);
void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
Expand Down
11 changes: 7 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class TritonLLVMConversionTarget : public ConversionTarget {
struct ConvertTritonAMDGPUToLLVM
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
ConvertTritonAMDGPUToLLVM> {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch) {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) {
this->arch = targetArch.str();
this->ftz = ftz;
}

void getDependentDialects(DialectRegistry &registry) const override {
Expand Down Expand Up @@ -174,7 +175,9 @@ struct ConvertTritonAMDGPUToLLVM
typeConverter, targetInfo, patterns, commonBenefit);
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, AMDBenefit);
populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns, AMDBenefit);
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
axisInfoAnalysis, allocation,
targetInfo, AMDBenefit);
AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns,
numWarps, axisInfoAnalysis,
AMDBenefit);
Expand Down Expand Up @@ -243,8 +246,8 @@ namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
}

} // namespace triton
Expand Down
7 changes: 4 additions & 3 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ namespace py = pybind11;
namespace {
void init_triton_amd_passes_ttgpuir(py::module &&m) {
using namespace mlir::triton;
m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch));
});
m.def("add_to_llvmir",
[](mlir::PassManager &pm, const std::string &arch, bool ftz) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
});
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(createConvertBuiltinFuncToLLVMPass());
});
Expand Down