9 changes: 4 additions & 5 deletions lib/Analysis/Allocation.cpp
@@ -68,6 +68,9 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}

if (isMfmaToDotShortcut(srcTy, dstTy))
return {};

// MmaToDotShortcut and MmaToMmaShortcut don't use shared mem
if (auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
if (mlir::isa<DotOperandEncodingAttr>(dstLayout)) {
@@ -111,11 +114,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

if (mlir::isa<AMDMfmaEncodingAttr>(srcLayout) &&
mlir::dyn_cast<AMDMfmaEncodingAttr>(srcLayout).getIsTransposed() &&
mlir::isa<DotOperandEncodingAttr>(dstLayout))
if (isMfmaToDotShortcut(srcTy, dstTy))
return {};
assert(!isMfmaToDotShortcut(srcTy, dstTy));

auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread =
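Aside: the early return added above relies on an implicit contract: an empty rep shape means the conversion needs no shared-memory scratch. A toy, self-contained sketch of that convention (hypothetical sizes, not the real analysis):

    #include <iostream>
    #include <vector>

    // Toy version of the convention in Allocation.cpp (hypothetical sizes):
    // returning an empty shape signals that a layout conversion needs no
    // shared-memory scratch buffer.
    std::vector<unsigned> repShapeForCvt(bool isShortcut) {
      if (isShortcut)
        return {};       // shortcut: data moves through registers only
      return {16, 16};   // otherwise: a per-CTA scratch tile
    }

    int main() {
      std::cout << "shortcut scratch dims: " << repShapeForCvt(true).size() << "\n"
                << "generic scratch dims:  " << repShapeForCvt(false).size() << "\n";
    }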
6 changes: 4 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -581,8 +581,10 @@ bool supportMMA(Value value, int version) {
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto mfmaLayout = cast<AMDMfmaEncodingAttr>(srcLayout);
auto dotOperandLayout = cast<DotOperandEncodingAttr>(dstLayout);
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcLayout);
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstLayout);
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
return false;
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
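Aside: the point of this change is the LLVM casting idiom. cast<> asserts when the type does not match, while dyn_cast<> returns null, so isMfmaToDotShortcut can now answer false for arbitrary layouts instead of crashing. A self-contained analogy using plain C++ dynamic_cast (toy types, not the MLIR API):

    #include <iostream>

    struct Attribute { virtual ~Attribute() = default; };
    struct AMDMfmaEncoding : Attribute { bool isTransposed = true; };
    struct BlockedEncoding : Attribute {};

    // A failed dynamic_cast (like MLIR's dyn_cast) yields nullptr, so the
    // query can reject non-MFMA layouts gracefully instead of asserting the
    // way cast<> does.
    bool isMfmaSource(const Attribute *layout) {
      auto *mfma = dynamic_cast<const AMDMfmaEncoding *>(layout);
      if (!mfma)
        return false;
      return mfma->isTransposed;
    }

    int main() {
      BlockedEncoding blocked;
      AMDMfmaEncoding mfma;
      std::cout << isMfmaSource(&blocked) << " " << isMfmaSource(&mfma) << "\n"; // 0 1
    }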
25 changes: 25 additions & 0 deletions test/Conversion/amd/math-denorm-handling.mlir
@@ -0,0 +1,25 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ


#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.amdgcn.exp2.f32
// LLVM_NO_FTZ: llvm.exp2.f32
%0 = math.exp2 %arg0 : tensor<64xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.exp2.f32
// LLVM_NO_FTZ: llvm.exp2.f32
%0 = math.exp %arg0 : tensor<64xf32, #blocked>
tt.return
}
}
27 changes: 27 additions & 0 deletions test/Conversion/amd/mfma-shortcut.mlir
@@ -0,0 +1,27 @@
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s

#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: shortcut_mfma16
tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: no_shortcut_mfma16
tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
// CHECK: store
// CHECK: load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
tt.return
}
}
10 changes: 9 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -158,7 +158,15 @@ def make_llir(src, metadata, options):
passes.convert.add_index_to_llvmir(pm)

passes.ttgpuir.add_allocate_shared_memory(pm)
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch)
## __HIP_FTZ controls the denorm-flushing behavior of the exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of the kernel arg `allow_flush_denorm`.
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
## depends on the value of the kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ defaults to 1 and is not exposed as a kernel argument.
## For now it is a developer-only control.
__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)

2 changes: 1 addition & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -25,7 +25,7 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();

#define GEN_PASS_REGISTRATION
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers

def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\")";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::math::MathDialect",
@@ -30,6 +30,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
let options = [
Option<"arch", "arch", "std::string", /*default*/"\"\"",
"gfx target device architecture, e.g., gfx942">,
Option<"ftz", "ftz", "bool", /*default*/"true",
"flush denorms for math functions">,
];
}

@@ -168,7 +168,5 @@ void populateConvertLayoutOpToLLVMPatterns(
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
patterns, benefit);
}
} // namespace mlir::triton::AMD
63 changes: 56 additions & 7 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -10,8 +10,10 @@
using namespace mlir;
using namespace mlir::triton;

using mlir::triton::gpu::appendOrGetExternFuncOp;
using mlir::triton::gpu::ElementwiseOpConversionBase;
using mlir::triton::gpu::getElementType;
using mlir::triton::gpu::getFunctionType;
using mlir::triton::gpu::MultipleOperandsRange;

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
@@ -1213,23 +1215,66 @@ struct ExpOpConversionApprox
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// For non-FP32 input, call __nv_expf for higher-precision calculation
// For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() != 32)
return {};

const double log2e = 1.4426950408889634;
Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));

return {rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
adaptor.getAttributes().getValue())};
// Here we use llvm.exp2.f32 instead of math::Exp2Op. The latter
// flushes denorms by default, but we want to preserve denorms by default
// for expOp.
StringRef funcName = "llvm.exp2.f32";
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {rewriter.create<LLVM::CallOp>(loc, funcOp, prod).getResult()};
}
};
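Aside: the pattern rests on the identity exp(x) = 2^(x * log2(e)), with log2(e) ≈ 1.4426950408889634 hard-coded above. A quick host-side check of the identity (ordinary libm calls, not the GPU lowering):

    #include <cmath>
    #include <cstdio>

    int main() {
      const float log2e = 1.4426950408889634f; // same constant as the pattern
      for (float x : {-2.0f, 0.5f, 3.0f}) {
        // math.exp is rewritten to exp2(x * log2e); the two forms agree to
        // within float rounding.
        std::printf("x=%5.2f  expf=%.7g  exp2f=%.7g\n", x, std::exp(x),
                    std::exp2(x * log2e));
      }
    }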

struct Exp2OpConversion
: ElementwiseOpConversionBase<mlir::math::Exp2Op, Exp2OpConversion> {
using ElementwiseOpConversionBase<
mlir::math::Exp2Op, Exp2OpConversion>::ElementwiseOpConversionBase;

explicit Exp2OpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz,
PatternBenefit benefit)
: ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
ftz(ftz) {}

SmallVector<Value> createDestOps(mlir::math::Exp2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() != 32)
return {};

// On the AMD backend, both intrinsics lower to the v_exp_f32 instruction,
// which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides
// direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts
// instructions to handle denorms iff `allow_flush_denorm` is False.
StringRef funcName = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32";
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
}

private:
bool ftz;
};

} // namespace

namespace mlir::triton::AMD {
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
const TargetInfo &targetInfo, PatternBenefit benefit) {

@@ -1257,11 +1302,15 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
targetInfo.getISAFamily(), benefit);

// ExpOpConversionApprox will try using ex2.approx if the input type is
// ExpOpConversionApprox will try using llvm.exp2.f32 if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
// a later pass will call __ocml_exp_f64 for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
// Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32,
// depending on the ftz flag, if the input type is FP32. For FP64 input,
// Exp2OpConversion will return failure and a later pass will call
// __ocml_exp2_f64 for higher-precision calculation
patterns.add<Exp2OpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
mlir::triton::populateElementwiseOpToLLVMPatterns(
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
mlir::triton::populateMinMaxFOpToLLVMPattern(
@@ -17,7 +17,7 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit);
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
const TargetInfo &targetInfo, PatternBenefit benefit);
void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
59 changes: 36 additions & 23 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -63,8 +63,9 @@ class TritonLLVMConversionTarget : public ConversionTarget {
struct ConvertTritonAMDGPUToLLVM
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
ConvertTritonAMDGPUToLLVM> {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch) {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) {
this->arch = targetArch.str();
this->ftz = ftz;
}

void getDependentDialects(DialectRegistry &registry) const override {
@@ -145,48 +146,60 @@ struct ConvertTritonAMDGPUToLLVM
OpBuilder::InsertPoint indexInsertPoint;

RewritePatternSet patterns(context);
int benefit = patternBenefitPrioritizeOverLLVMConversions;
auto populatePatterns1 = [&](auto populateFunc) {
int commonBenefit = patternBenefitPrioritizeOverLLVMConversions;
// Give AMD-specific patterns a higher benefit so they apply before the
// common patterns (see the toy sketch after this hunk).
int AMDBenefit = commonBenefit + 1;
auto populatePatterns1 = [&](auto populateFunc, int benefit) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, benefit);
};

auto populatePatterns5 = [&](auto populateFunc) {
auto populatePatterns5 = [&](auto populateFunc, int benefit) {
populateFunc(typeConverter, patterns, benefit);
};

auto populatePatterns6 = [&](auto populateFunc) {
auto populatePatterns6 = [&](auto populateFunc, int benefit) {
populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
allocation, targetInfo, benefit);
};

auto populatePatterns7 = [&](auto populateFunc) {
auto populatePatterns7 = [&](auto populateFunc, int benefit) {
populateFunc(typeConverter, patterns, targetInfo, benefit);
};

AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
patterns, numWarps,
axisInfoAnalysis, benefit);
axisInfoAnalysis, AMDBenefit);
mlir::triton::populateConvertLayoutOpToLLVMPatterns(
typeConverter, targetInfo, patterns, commonBenefit);
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, benefit);
populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns);
axisInfoAnalysis, AMDBenefit);
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
axisInfoAnalysis, allocation,
targetInfo, AMDBenefit);
AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns,
numWarps, axisInfoAnalysis, benefit);
populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns);
populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns);
populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns);
populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns);
numWarps, axisInfoAnalysis,
AMDBenefit);
populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns,
commonBenefit);
populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns,
commonBenefit);
populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns,
commonBenefit);
populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns,
commonBenefit);
mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
patterns, benefit);
patterns, commonBenefit);
mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
patterns, benefit);
patterns, commonBenefit);
mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
targetInfo, commonBenefit);
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
benefit);
commonBenefit);
mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit);
targetInfo, commonBenefit);
AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit);
// TODO(thomas): this should probably be done in a separate step to not
// interfere with our own lowering of arith ops. Add arith/math's patterns
// to help convert scalar expression to LLVM.
@@ -200,7 +213,7 @@
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
targetInfo, commonBenefit);
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
return signalPassFailure();
}
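Aside: a toy model of the benefit mechanism used in this hunk (an analogy, not MLIR's actual rewrite driver). When several patterns match the same op, the higher-benefit pattern is tried first, which is why the AMD-specific patterns get commonBenefit + 1:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Pattern {
      std::string name;
      int benefit;
    };

    int main() {
      int commonBenefit = 10; // stand-in for patternBenefitPrioritizeOverLLVMConversions
      int amdBenefit = commonBenefit + 1;
      std::vector<Pattern> candidates = {
          {"common ConvertLayoutOp lowering", commonBenefit},
          {"AMD ConvertLayoutOp lowering", amdBenefit},
      };
      // Higher benefit is tried first, so the AMD pattern wins whenever both
      // match the same op.
      std::stable_sort(candidates.begin(), candidates.end(),
                       [](const Pattern &a, const Pattern &b) {
                         return a.benefit > b.benefit;
                       });
      std::cout << "tried first: " << candidates.front().name << "\n";
    }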
@@ -233,8 +246,8 @@ namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
}

} // namespace triton
7 changes: 4 additions & 3 deletions third_party/amd/python/triton_amd.cc
@@ -34,9 +34,10 @@ namespace py = pybind11;
namespace {
void init_triton_amd_passes_ttgpuir(py::module &&m) {
using namespace mlir::triton;
m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch));
});
m.def("add_to_llvmir",
[](mlir::PassManager &pm, const std::string &arch, bool ftz) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
});
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(createConvertBuiltinFuncToLLVMPass());
});