diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 7f0a2729a4..0ee1707a4c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -33,7 +33,8 @@ class TargetInfoBase { virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const = 0; + unsigned numLaneToReduce, + unsigned interleave) const = 0; virtual bool processReplicaUsingStMatrix( ConversionPatternRewriter &rewriter, Location loc, Value smemBase, diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 4d036c21a8..66f7b90e0d 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -176,8 +176,8 @@ struct ReduceOpConversion void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - auto success = - targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); if (success) return; for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 1e8f33c3f5..92a87794c9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -118,7 +118,8 @@ Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { + unsigned numLaneToReduce, + unsigned interleave) const { return false; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 0cf4535063..84733661cf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -41,7 +41,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const override; + unsigned numLaneToReduce, unsigned interleave) const override; bool processReplicaUsingStMatrix( ConversionPatternRewriter &rewriter, Location loc, Value smemBase, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 7b5cb90868..e8afb63d28 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -24,6 +24,7 @@ add_triton_library(TritonIntelGPUToLLVM TypeConverter.cpp Utility.cpp ViewOpToLLVM.cpp + GenIntrinsicHelper.cpp DEPENDS TritonIntelGPUConversionPassIncGen @@ -33,4 +34,5 @@ add_triton_library(TritonIntelGPUToLLVM TritonGENIR TritonGENToLLVM TritonIntelGPUIR + MLIRTargetLLVMIRImport ) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.cpp new file mode 100644 index 0000000000..ee4f2a97d1 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.cpp @@ -0,0 +1,145 @@ +#include "Utility.h" + +#include "GenIntrinsicHelper.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Target/LLVMIR/ModuleImport.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" +#include "mlir/Target/LLVMIR/TypeToLLVM.h" + +namespace mlir { +namespace triton { +namespace intel { + +// The code convert the function attribute from the original here: +// https://github.com/llvm/llvm-project/blob/e575b7cb7a64297583d6382c16ce264d9fe45d08/mlir/lib/Target/LLVMIR/ModuleImport.cpp#L1547 +// List of LLVM IR attributes that map to an explicit attribute on the MLIR +// LLVMFuncOp. +static constexpr std::array ExplicitAttributes{ + StringLiteral("aarch64_pstate_sm_enabled"), + StringLiteral("aarch64_pstate_sm_body"), + StringLiteral("aarch64_pstate_sm_compatible"), + StringLiteral("aarch64_new_za"), + StringLiteral("aarch64_preserves_za"), + StringLiteral("aarch64_in_za"), + StringLiteral("aarch64_out_za"), + StringLiteral("aarch64_inout_za"), + StringLiteral("vscale_range"), + StringLiteral("frame-pointer"), + StringLiteral("target-features"), + StringLiteral("unsafe-fp-math"), + StringLiteral("no-infs-fp-math"), + StringLiteral("no-nans-fp-math"), + StringLiteral("approx-func-fp-math"), + StringLiteral("no-signed-zeros-fp-math"), +}; + +static void processPassthroughAttrs(llvm::Function *func, + mlir::LLVM::LLVMFuncOp funcOp) { + MLIRContext *context = funcOp.getContext(); + SmallVector passthroughs; + llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes( + llvm::AttributeList::AttrIndex::FunctionIndex); + for (llvm::Attribute attr : funcAttrs) { + // Skip the memory attribute since the LLVMFuncOp has an explicit memory + // attribute. + if (attr.hasAttribute(llvm::Attribute::Memory)) + continue; + + // Skip invalid type attributes. + if (attr.isTypeAttribute()) { + emitWarning(funcOp.getLoc(), + "type attributes on a function are invalid, skipping it"); + continue; + } + + StringRef attrName; + if (attr.isStringAttribute()) + attrName = attr.getKindAsString(); + else + attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum()); + auto keyAttr = StringAttr::get(context, attrName); + + // Skip attributes that map to an explicit attribute on the LLVMFuncOp. + if (llvm::is_contained(ExplicitAttributes, attrName)) + continue; + + if (attr.isStringAttribute()) { + StringRef val = attr.getValueAsString(); + if (val.empty()) { + passthroughs.push_back(keyAttr); + continue; + } + passthroughs.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isIntAttribute()) { + auto val = std::to_string(attr.getValueAsInt()); + passthroughs.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isEnumAttribute()) { + passthroughs.push_back(keyAttr); + continue; + } + + llvm_unreachable("unexpected attribute kind"); + } + + if (!passthroughs.empty()) + funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs)); +} + +mlir::LLVM::LLVMFuncOp +appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id, + ArrayRef mlirTys) { + auto mlirContext = builder.getContext(); + + SmallVector llvmTys; + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + std::make_unique("temp", llvmContext); + mlir::LLVM::TypeToLLVMIRTranslator llvmToMLIR(llvmContext); + for (mlir::Type *ty : mlirTys) { + llvmTys.push_back(llvmToMLIR.translateType(*ty)); + } + auto llvmFunc = + llvm::GenISAIntrinsic::getDeclaration(llvmModule.get(), id, llvmTys); + + auto genISAName = llvmFunc->getName(); + auto llvmFuncType = llvmFunc->getFunctionType(); + LLVM::TypeFromLLVMIRTranslator mlirFromLLVM(*mlirContext); + auto mlirFuncTy = mlirFromLLVM.translateType(llvmFuncType); + mlir::LLVM::LLVMFunctionType funcTy = + cast(mlirFuncTy); + + auto funcName = StringAttr::get(mlirContext, genISAName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom( + builder.getBlock() + ->getParent() + ->getParentOfType(), + funcName); + + if (funcOp) + return cast(*funcOp); + + auto parent = builder.getBlock() + ->getParent() + ->getParentOfType(); + mlir::OpBuilder b(parent); + auto ret = + b.create(mlir::UnknownLoc::get(mlirContext), genISAName, + funcTy, LLVM::Linkage::External, + /*dsoLocal*/ false, LLVM::CConv::C, + /*comdat=*/SymbolRefAttr{}); + + processPassthroughAttrs(llvmFunc, ret); + + return ret; +} + +} // namespace intel +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.h b/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.h new file mode 100644 index 0000000000..9fc70e673e --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.h @@ -0,0 +1,115 @@ +//===- GenIntrinsicHelper.h - Gen intrinsic helper ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_VCINTRINSICHELPER_H +#define TRITON_VCINTRINSICHELPER_H + +#include "TritonGENToLLVM/GenIntrinsics.h" +#include "Utility.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include +#include + +namespace mlir { +namespace triton { +namespace intel { + +mlir::LLVM::LLVMFuncOp +appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id, + ArrayRef mlirTys); + +class GenISA_WaveAll { +public: + // Enum def copied from IGC. + enum class WaveOps : unsigned int { + SUM, + PROD, + UMIN, + UMAX, + IMIN, + IMAX, + OR, + XOR, + AND, + FSUM, + FPROD, + FMIN, + FMAX, + UNDEF + }; + + explicit GenISA_WaveAll(OpBuilder &builder, Type retTy) : builder(builder) { + // get GenISA intrinsic declaration. + intrinsicDecl = appendOrGetGenISADeclaration( + builder, llvm::GenISAIntrinsic::ID::GenISA_WaveAll, {&retTy}); + } + + template + Value operator()(OpBuilder &rewriter, Location loc, Args... args) { + auto funName = intrinsicDecl.getName(); + auto retType = intrinsicDecl.getResultTypes(); + auto funCall = rewriter.create(loc, retType, funName, + ValueRange{args...}); + return funCall.getResult(); + } + +private: + OpBuilder &builder; + LLVM::LLVMFuncOp intrinsicDecl; +}; + +class GenISA_WaveCluster { +public: + // Enum def copied from IGC. + enum class WaveOps : unsigned int { + SUM, + PROD, + UMIN, + UMAX, + IMIN, + IMAX, + OR, + XOR, + AND, + FSUM, + FPROD, + FMIN, + FMAX, + UNDEF + }; + + explicit GenISA_WaveCluster(OpBuilder &builder, Type retTy) + : builder(builder) { + // get GenISA intrinsic declaration. + intrinsicDecl = appendOrGetGenISADeclaration( + builder, llvm::GenISAIntrinsic::ID::GenISA_WaveClustered, {&retTy}); + } + + template + Value operator()(OpBuilder &rewriter, Location loc, Args... args) { + auto funName = intrinsicDecl.getName(); + auto retType = intrinsicDecl.getResultTypes(); + auto funCall = rewriter.create(loc, retType, funName, + ValueRange{args...}); + return funCall.getResult(); + } + +private: + OpBuilder &builder; + LLVM::LLVMFuncOp intrinsicDecl; +}; + +} // namespace intel +} // namespace triton +} // namespace mlir + +#endif // TRITON_VCINTRINSICHELPER_H diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp index 9cec7c2116..cb6dbd91dc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp @@ -170,13 +170,13 @@ struct ReduceOpConversion } } - // Apply warp reduction across the given number of contiguous lanes using op - // region and the accumulator values as source. + // Apply warp reduction across the given number of the lanes and the + // interleave using op region and the accumulator values as source. void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - auto success = - targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); if (success) return; for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp index 867f857534..9a22fc8c36 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TargetInfo.h" +#include "GenIntrinsicHelper.h" #include "Utility.h" using namespace mlir; @@ -84,9 +85,61 @@ Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { - assert("TODO: implement warpReduce on XPU"); - return false; + unsigned numLaneToReduce, + unsigned interleave) const { + // No horizontal reduce required. + if (numLaneToReduce == 1) + return false; + // horizontal reduce with interleave stride not support. + if (interleave > 1) + return false; + // Check if it is a simple reduce operation supported by Wave Op. + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return false; + auto &combineOp = op.getCombineOp(); + if (combineOp.getBlocks().size() > 1) + return false; + Block &block = *combineOp.begin(); + Operation *yield = block.getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return false; + if (reduceOp->getOperand(0) != block.getArgument(0) || + reduceOp->getOperand(1) != block.getArgument(1)) + return false; + + GenISA_WaveAll::WaveOps waveOp = GenISA_WaveAll::WaveOps::UNDEF; + if (isa(reduceOp)) + waveOp = GenISA_WaveAll::WaveOps::SUM; + if (isa(reduceOp)) + waveOp = GenISA_WaveAll::WaveOps::FMAX; + if (waveOp == GenISA_WaveAll::WaveOps::UNDEF) + return false; + + auto mod = op.getOperation()->getParentOfType(); + unsigned threadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + + if (threadsPerWarp > numLaneToReduce) { + GenISA_WaveCluster waveClusterOp(rewriter, + reduceOp->getResult(0).getType()); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = waveClusterOp(rewriter, op->getLoc(), acc[i], + int_val(8, (unsigned)waveOp), + i32_val(numLaneToReduce), i32_val(0)); + } + } else if (threadsPerWarp == numLaneToReduce) { + GenISA_WaveAll waveAllOp(rewriter, reduceOp->getResult(0).getType()); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = waveAllOp(rewriter, op->getLoc(), acc[i], + int_val(8, (unsigned)waveOp), i32_val(0)); + } + } else { + llvm_unreachable("it is ilegal to reduce the lane number > warp size"); + } + + return true; } bool TargetInfo::processReplicaUsingStMatrix( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index 62e266a56c..675988b714 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -43,7 +43,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const override; + unsigned numLaneToReduce, unsigned interleave) const override; bool processReplicaUsingStMatrix( ConversionPatternRewriter &rewriter, Location loc, Value smemBase, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 0afdf6fbaf..7f811350aa 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -323,7 +323,8 @@ Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, } bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { + unsigned numLaneToReduce, + unsigned interleave) const { if (auto kind = matchReduxKind(op, computeCapability)) { // Based on benchmarking on A100 redux op gives a speed up only when doing // a single reduction (not partitioned) and when the mask is static. diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 5a5a456533..9b59993e6a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -36,7 +36,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const override; + unsigned numLaneToReduce, unsigned interleave) const override; bool processReplicaUsingStMatrix( ConversionPatternRewriter &rewriter, Location loc, Value smemBase,