Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class TargetInfoBase {

virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const = 0;
unsigned numLaneToReduce,
unsigned interleave) const = 0;

virtual bool processReplicaUsingStMatrix(
ConversionPatternRewriter &rewriter, Location loc, Value smemBase,
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ struct ReduceOpConversion
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce, unsigned interleave) const {
auto success =
targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce);
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
numLaneToReduce, interleave);
if (success)
return;
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
Expand Down
3 changes: 2 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc,

bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const {
unsigned numLaneToReduce,
unsigned interleave) const {
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

bool warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const override;
unsigned numLaneToReduce, unsigned interleave) const override;

bool processReplicaUsingStMatrix(
ConversionPatternRewriter &rewriter, Location loc, Value smemBase,
Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_triton_library(TritonIntelGPUToLLVM
TypeConverter.cpp
Utility.cpp
ViewOpToLLVM.cpp
GenIntrinsicHelper.cpp

DEPENDS
TritonIntelGPUConversionPassIncGen
Expand All @@ -33,4 +34,5 @@ add_triton_library(TritonIntelGPUToLLVM
TritonGENIR
TritonGENToLLVM
TritonIntelGPUIR
MLIRTargetLLVMIRImport
)
145 changes: 145 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#include "Utility.h"

#include "GenIntrinsicHelper.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"

#include <cassert>

namespace mlir {
namespace triton {
namespace intel {

// The function-attribute conversion code below is adapted from the upstream
// MLIR importer:
// https://github.com/llvm/llvm-project/blob/e575b7cb7a64297583d6382c16ce264d9fe45d08/mlir/lib/Target/LLVMIR/ModuleImport.cpp#L1547
// List of LLVM IR attributes that map to an explicit attribute on the MLIR
// LLVMFuncOp. processPassthroughAttrs() skips any attribute whose name appears
// here, so these are never added to the passthrough list.
static constexpr std::array ExplicitAttributes{
StringLiteral("aarch64_pstate_sm_enabled"),
StringLiteral("aarch64_pstate_sm_body"),
StringLiteral("aarch64_pstate_sm_compatible"),
StringLiteral("aarch64_new_za"),
StringLiteral("aarch64_preserves_za"),
StringLiteral("aarch64_in_za"),
StringLiteral("aarch64_out_za"),
StringLiteral("aarch64_inout_za"),
StringLiteral("vscale_range"),
StringLiteral("frame-pointer"),
StringLiteral("target-features"),
StringLiteral("unsafe-fp-math"),
StringLiteral("no-infs-fp-math"),
StringLiteral("no-nans-fp-math"),
StringLiteral("approx-func-fp-math"),
StringLiteral("no-signed-zeros-fp-math"),
};

/// Mirrors the function-level attributes of the LLVM IR function `func` onto
/// `funcOp` as its passthrough attribute.
///
/// The memory attribute, type attributes (reported with a warning) and every
/// name listed in ExplicitAttributes are skipped. Each remaining attribute is
/// recorded either as a bare key (enum attributes and string attributes with
/// an empty value) or as a `[key, value]` string pair (string attributes with
/// a value and integer attributes). The passthrough attribute is only set
/// when at least one entry was collected.
static void processPassthroughAttrs(llvm::Function *func,
                                    mlir::LLVM::LLVMFuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  llvm::AttributeSet fnAttrs = func->getAttributes().getAttributes(
      llvm::AttributeList::AttrIndex::FunctionIndex);

  SmallVector<Attribute> collected;
  // Appends a `[key, value]` pair, with the value stored as a string.
  auto addKeyValue = [&](StringAttr key, StringRef value) {
    collected.push_back(
        ArrayAttr::get(ctx, {key, StringAttr::get(ctx, value)}));
  };

  for (llvm::Attribute attr : fnAttrs) {
    // LLVMFuncOp carries its own explicit memory attribute.
    if (attr.hasAttribute(llvm::Attribute::Memory))
      continue;

    // Type attributes are not valid on a function; warn and drop them.
    if (attr.isTypeAttribute()) {
      emitWarning(funcOp.getLoc(),
                  "type attributes on a function are invalid, skipping it");
      continue;
    }

    const bool isString = attr.isStringAttribute();
    StringRef name =
        isString ? attr.getKindAsString()
                 : llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
    auto key = StringAttr::get(ctx, name);

    // Attributes with a dedicated LLVMFuncOp field are not passed through.
    if (llvm::is_contained(ExplicitAttributes, name))
      continue;

    if (isString) {
      StringRef value = attr.getValueAsString();
      if (value.empty())
        collected.push_back(key);
      else
        addKeyValue(key, value);
    } else if (attr.isIntAttribute()) {
      addKeyValue(key, std::to_string(attr.getValueAsInt()));
    } else if (attr.isEnumAttribute()) {
      collected.push_back(key);
    } else {
      llvm_unreachable("unexpected attribute kind");
    }
  }

  if (collected.empty())
    return;
  funcOp.setPassthroughAttr(ArrayAttr::get(ctx, collected));
}

/// Returns the MLIR declaration of the GenISA intrinsic `id` overloaded on
/// `mlirTys`, creating the declaration if it does not exist yet.
///
/// The intrinsic is first declared in a scratch LLVM module so that its name,
/// function type and attributes can be read back. The name is then looked up
/// starting from the LLVMFuncOp that encloses the builder's current block; if
/// a symbol with that name is found it is returned as is. Otherwise a new
/// external LLVMFuncOp is created with a builder anchored at that enclosing
/// function, and the LLVM function attributes are copied onto it as
/// passthrough attributes.
///
/// \param builder  Builder whose current block must be nested in an
///                 LLVMFuncOp.
/// \param id       GenISA intrinsic to declare.
/// \param mlirTys  Pointers to the MLIR types the intrinsic is overloaded on;
///                 each is translated to its LLVM IR counterpart.
/// \returns The existing or newly created declaration.
mlir::LLVM::LLVMFuncOp
appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id,
                             ArrayRef<mlir::Type *> mlirTys) {
  MLIRContext *mlirContext = builder.getContext();

  // Both the symbol lookup and the insertion of a new declaration are
  // anchored at the enclosing LLVM function, so resolve it once up front.
  auto parent = builder.getBlock()
                    ->getParent()
                    ->getParentOfType<mlir::LLVM::LLVMFuncOp>();
  assert(parent &&
         "GenISA declarations can only be requested from inside an LLVMFuncOp");

  // Declare the intrinsic in a throw-away LLVM module. The context must
  // outlive the module, hence the declaration order.
  llvm::LLVMContext llvmContext;
  auto llvmModule = std::make_unique<llvm::Module>("temp", llvmContext);
  mlir::LLVM::TypeToLLVMIRTranslator mlirToLLVM(llvmContext);
  SmallVector<llvm::Type *, 4> llvmTys;
  llvmTys.reserve(mlirTys.size());
  for (mlir::Type *ty : mlirTys)
    llvmTys.push_back(mlirToLLVM.translateType(*ty));
  auto llvmFunc =
      llvm::GenISAIntrinsic::getDeclaration(llvmModule.get(), id, llvmTys);

  // Reuse an existing declaration with the same name, if any.
  StringRef genISAName = llvmFunc->getName();
  auto funcName = StringAttr::get(mlirContext, genISAName);
  if (Operation *existing =
          SymbolTable::lookupNearestSymbolFrom(parent, funcName))
    return cast<mlir::LLVM::LLVMFuncOp>(*existing);

  // Only a new declaration needs the function type translated back to MLIR.
  LLVM::TypeFromLLVMIRTranslator mlirFromLLVM(*mlirContext);
  auto funcTy = cast<mlir::LLVM::LLVMFunctionType>(
      mlirFromLLVM.translateType(llvmFunc->getFunctionType()));

  mlir::OpBuilder declBuilder(parent);
  auto decl = declBuilder.create<LLVM::LLVMFuncOp>(
      mlir::UnknownLoc::get(mlirContext), genISAName, funcTy,
      LLVM::Linkage::External,
      /*dsoLocal*/ false, LLVM::CConv::C,
      /*comdat=*/SymbolRefAttr{});

  processPassthroughAttrs(llvmFunc, decl);

  return decl;
}

} // namespace intel
} // namespace triton
} // namespace mlir
115 changes: 115 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/GenIntrinsicHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//===- GenIntrinsicHelper.h - Gen intrinsic helper ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_VCINTRINSICHELPER_H
#define TRITON_VCINTRINSICHELPER_H

#include "TritonGENToLLVM/GenIntrinsics.h"
#include "Utility.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
#include <string>

namespace mlir {
namespace triton {
namespace intel {

mlir::LLVM::LLVMFuncOp
appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id,
ArrayRef<mlir::Type *> mlirTys);

class GenISA_WaveAll {
public:
// Enum def copied from IGC.
enum class WaveOps : unsigned int {
SUM,
PROD,
UMIN,
UMAX,
IMIN,
IMAX,
OR,
XOR,
AND,
FSUM,
FPROD,
FMIN,
FMAX,
UNDEF
};

explicit GenISA_WaveAll(OpBuilder &builder, Type retTy) : builder(builder) {
// get GenISA intrinsic declaration.
intrinsicDecl = appendOrGetGenISADeclaration(
builder, llvm::GenISAIntrinsic::ID::GenISA_WaveAll, {&retTy});
}

template <typename... Args>
Value operator()(OpBuilder &rewriter, Location loc, Args... args) {
auto funName = intrinsicDecl.getName();
auto retType = intrinsicDecl.getResultTypes();
auto funCall = rewriter.create<LLVM::CallOp>(loc, retType, funName,
ValueRange{args...});
return funCall.getResult();
}

private:
OpBuilder &builder;
LLVM::LLVMFuncOp intrinsicDecl;
};

class GenISA_WaveCluster {
public:
// Enum def copied from IGC.
enum class WaveOps : unsigned int {
SUM,
PROD,
UMIN,
UMAX,
IMIN,
IMAX,
OR,
XOR,
AND,
FSUM,
FPROD,
FMIN,
FMAX,
UNDEF
};

explicit GenISA_WaveCluster(OpBuilder &builder, Type retTy)
: builder(builder) {
// get GenISA intrinsic declaration.
intrinsicDecl = appendOrGetGenISADeclaration(
builder, llvm::GenISAIntrinsic::ID::GenISA_WaveClustered, {&retTy});
}

template <typename... Args>
Value operator()(OpBuilder &rewriter, Location loc, Args... args) {
auto funName = intrinsicDecl.getName();
auto retType = intrinsicDecl.getResultTypes();
auto funCall = rewriter.create<LLVM::CallOp>(loc, retType, funName,
ValueRange{args...});
return funCall.getResult();
}

private:
OpBuilder &builder;
LLVM::LLVMFuncOp intrinsicDecl;
};

} // namespace intel
} // namespace triton
} // namespace mlir

#endif // TRITON_VCINTRINSICHELPER_H
8 changes: 4 additions & 4 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ struct ReduceOpConversion
}
}

// Apply warp reduction across the given number of contiguous lanes using op
// region and the accumulator values as source.
// Apply warp reduction across the given number of lanes with the given
// interleave, using the op region and the accumulator values as source.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce, unsigned interleave) const {
auto success =
targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce);
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
numLaneToReduce, interleave);
if (success)
return;
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
Expand Down
Loading