Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton::cpu {
// Target hooks used by the TritonCPU -> LLVM conversion patterns to emit
// CPU-specific lowering for program ids and printing.
class CPUTargetInfo {
public:
// Note: we may revisit for different CPU ISAs like AVX and Neon.
CPUTargetInfo() {}

// Returns the program id along `axis` (0, 1 or 2) for the kernel `funcOp`.
// The CPU implementation in this change reads it from the last three
// arguments of `funcOp`; `rewriter` and `loc` are not used there.
Value programId(ConversionPatternRewriter &rewriter, Location loc,
LLVM::LLVMFuncOp funcOp, int axis) const;

// Emits a call to `printf` with `formatStrStart` as the format string,
// followed by `args`. The CPU implementation in this change ignores
// `formatStrByteCount`.
void printf(ConversionPatternRewriter &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const;

~CPUTargetInfo() {}
};
} // namespace mlir::triton::cpu
#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H
#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H

#include "CPUTargetInfo.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

using namespace mlir;
Expand All @@ -17,6 +19,11 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

// Adds the SPMD op lowering patterns (the tt.get_program_id conversion in this
// change) to `patterns`. `targetInfo` is captured by reference by the
// registered pattern, so it must outlive the pattern set.
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const cpu::CPUTargetInfo &targetInfo,
PatternBenefit benefit);

// Adds control-flow lowering patterns to `patterns` with the given `benefit`.
// NOTE(review): the implementation is not part of this view; presumably it
// lowers Triton control-flow ops (e.g. return/call) -- confirm.
void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);
Expand All @@ -27,6 +34,7 @@ void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,

// Adds the tt.print lowering pattern to `patterns`. `targetInfo` is captured
// by reference by the registered pattern, so it must outlive the pattern set.
void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo,
PatternBenefit benefit);

} // namespace cpu
Expand Down
13 changes: 4 additions & 9 deletions include/triton/Conversion/TritonCPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,10 @@
using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace LLVM {
// TODO: Do better refactoring.
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

// TODO: Not sure we need this for CPU backends.
// A function is treated as a kernel exactly when its symbol visibility is
// public.
inline bool isKernel(FunctionOpInterface funcOp) {
  const auto visibility = funcOp.getVisibility();
  return visibility == SymbolTable::Visibility::Public;
}

} // namespace LLVM
} // namespace mlir
#undef DEBUG_TYPE
#define DEBUG_TYPE "ttcpu_to_llvm"

#endif
3 changes: 3 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
add_triton_library(TritonCPUToLLVM
ControlFlowOpToLLVM.cpp
CPUTargetInfo.cpp
FuncOpToLLVM.cpp
PrintOpToLLVM.cpp
SPMDOpToLLVM.cpp
TypeConverter.cpp
TritonCPUToLLVM.cpp

Expand Down
49 changes: 49 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"

namespace {
// Returns the module-level declaration of `int printf(char *format, ...)`,
// creating it at the start of the enclosing module when no symbol named
// "printf" exists yet. The rewriter's insertion point is left untouched.
LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) {
  StringRef printfName("printf");
  auto parentModule =
      rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();

  // Reuse an already declared symbol so that only one declaration is emitted.
  if (Operation *existing = parentModule.lookupSymbol(printfName))
    return cast<LLVM::LLVMFuncOp>(*existing);

  auto *ctx = rewriter.getContext();

  // Signature: one pointer parameter (the format string), variadic, i32 result.
  SmallVector<Type> paramTypes{ptr_ty(ctx)};
  auto printfType =
      LLVM::LLVMFunctionType::get(i32_ty, paramTypes, /*isVarArg=*/true);

  // The guard restores the caller's insertion point when this function returns.
  ConversionPatternRewriter::InsertionGuard restoreInsertionPoint(rewriter);
  rewriter.setInsertionPointToStart(parentModule.getBody());
  return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), printfName,
                                           printfType);
}
} // namespace

namespace mlir::triton::cpu {

// Returns the program id along `axis` (0, 1 or 2) of the kernel `funcOp`.
//
// On CPU the program ids are passed to the kernel as its trailing three
// function arguments (__grid0, __grid1, __grid2, all i32), so this simply
// selects the matching argument; no IR is created and `rewriter`/`loc` are
// unused.
Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter,
                               Location loc, LLVM::LLVMFuncOp funcOp,
                               int axis) const {
  assert(0 <= axis && axis < 3);
  assert(funcOp && funcOp.getArguments().size() >= 3);

  // Index of __grid<axis>, counted from the front of the argument list.
  const auto argCount = funcOp.getArguments().size();
  const auto gridArgIndex = argCount - 3 + axis;
  return funcOp.getArgument(gridArgIndex);
}

// Emits a call to the C `printf` with `formatStrStart` as the format string
// followed by `args`. `formatStrByteCount` is unused on CPU.
//
// `printf` is variadic, so every argument has to be passed the way a C caller
// would pass it after the default argument promotions: floating-point values
// narrower than double are extended to f64 and integers narrower than 32 bits
// are extended to i32. Without this, "%f" / "%i" / "%u" conversions would read
// the argument with the wrong width.
void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter,
                           Value formatStrStart, int /*formatStrByteCount*/,
                           ValueRange args) const {
  // `loc` is also picked up implicitly by the `call` helper used below.
  auto loc = UnknownLoc::get(rewriter.getContext());

  // Applies the C default argument promotions to a single variadic argument.
  auto promoteVariadicArg = [&](Value arg) -> Value {
    Type type = arg.getType();
    if (type.isF16() || type.isBF16() || type.isF32())
      return rewriter.create<LLVM::FPExtOp>(loc, rewriter.getF64Type(), arg);

    if (auto intType = type.dyn_cast<IntegerType>()) {
      if (intType.getWidth() < 32) {
        Type i32Type = rewriter.getI32Type();
        // i1 is zero-extended so that `true` is printed as 1 rather than -1.
        // NOTE(review): other signless integers are sign-extended; confirm
        // this matches the signedness assumed when the format string is built.
        if (intType.isUnsigned() || intType.getWidth() == 1)
          return rewriter.create<LLVM::ZExtOp>(loc, i32Type, arg);
        return rewriter.create<LLVM::SExtOp>(loc, i32Type, arg);
      }
    }
    return arg;
  };

  SmallVector<Value> formatStrAndArgs{formatStrStart};
  formatStrAndArgs.reserve(args.size() + 1);
  for (Value arg : args)
    formatStrAndArgs.push_back(promoteVariadicArg(arg));

  call(getPrintfDeclaration(rewriter), formatStrAndArgs);
}
} // namespace mlir::triton::cpu
131 changes: 131 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace {

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

// Lowers `triton::PrintOp` to a single `printf` call emitted through
// `CPUTargetInfo`. The output line has the form
//   pid (<pid0>, <pid1>, <pid2>)<prefix><operand0>, <operand1>, ...
// Only scalar operands are supported for now; anything else makes the pattern
// fail to match instead of aborting the compiler.
struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
  explicit PrintOpConversion(LLVMTypeConverter &typeConverter,
                             const CPUTargetInfo &targetInfo,
                             PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<triton::PrintOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // Validate everything up front so that no IR is emitted for a print that
    // cannot be lowered.
    for (Type operandType : op->getOperandTypes()) {
      if (operandType.isa<RankedTensorType>())
        return rewriter.notifyMatchFailure(op,
                                           "Not implemented for tensor types");
    }

    // The program ids are read from the trailing three arguments of the
    // enclosing LLVM function, so that function must exist and be large enough.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp || funcOp.getArguments().size() < 3)
      return rewriter.notifyMatchFailure(
          op, "expected an enclosing llvm.func carrying the grid id arguments");

    auto getPid = [&](int axis) {
      return targetInfo.programId(rewriter, loc, funcOp, axis);
    };
    SmallVector<Value> values = {getPid(0), getPid(1), getPid(2)};

    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    os << "pid (" << getFormatSubstr(values[0]) << ", "
       << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2])
       << ")" << op.getPrefix();

    for (size_t i = 0; i < op.getNumOperands(); i++) {
      auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);

      // Only support scalars for now.
      if (elems.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only scalar operands are supported");

      if (i != 0) {
        os << ", ";
      }
      os << getFormatSubstr(elems[0]);
      values.push_back(elems[0]);
    }

    llPrintf(formatStr, values, rewriter);
    rewriter.eraseOp(op);
    return success();
  }

  // Returns the printf conversion specification used to print `value`.
  //   - pointers                 -> "%p"
  //   - hex == true              -> "0x%0<n>x" / "0x%0<n>llx", where n is the
  //                                 number of hex digits of the type (`width`
  //                                 is ignored in this mode)
  //   - bf16 / f16 / f32 / f64   -> "%[width]f"
  //   - signed integers          -> "%[width]i" ("lli" for 64-bit)
  //   - unsigned / signless ints -> "%[width]u" ("llu" for 64-bit)
  // Any other type trips an assertion and yields an empty string.
  // TODO: This code is the same as the GPU-backend code. Consider refactoring.
  std::string getFormatSubstr(Value value, bool hex = false,
                              std::optional<int> width = std::nullopt) const {
    Type type = value.getType();
    if (type.isa<LLVM::LLVMPointerType>()) {
      return "%p";
    }
    // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the
    // type (so 4 for fp16, 8 for int32, 16 for int64).
    if (hex) {
      // Ignore `width` for `hex` values, pad to typeWidth.
      std::string ret =
          "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4);
      if (type.getIntOrFloatBitWidth() > 32) {
        ret += "ll";
      }
      ret += "x";
      return ret;
    }

    // `hex` is known to be false from here on, so only `width` can contribute
    // to the prefix.
    std::string prefix = "%";
    if (width.has_value()) {
      prefix += std::to_string(*width);
    }

    if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
      return prefix + "f";
    } else if (type.isSignedInteger()) {
      if (type.getIntOrFloatBitWidth() == 64)
        return prefix + "lli";
      else
        return prefix + "i";
    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
      if (type.getIntOrFloatBitWidth() == 64)
        return prefix + "llu";
      else
        return prefix + "u";
    }
    assert(false && "not supported type");
    return "";
  }

  // Emits a printf of `msg` (with a trailing newline appended) and `args`.
  // The format string is materialized in the module via
  // `LLVM::addStringToModule`; the resulting value is returned. When
  // `formatStrByteCount` is non-null it receives the size in bytes of the
  // stored string, including the newline and the explicit NUL terminator.
  Value llPrintf(StringRef msg, ValueRange args,
                 ConversionPatternRewriter &rewriter,
                 int *formatStrByteCount = nullptr) const {
    assert(!msg.empty() && "printf with empty string not supported");
    llvm::SmallString<64> msgNewline(msg);
    msgNewline.push_back('\n');
    msgNewline.push_back('\0');
    Value msgValue =
        LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
                                rewriter, "printfFormat_", msgNewline);
    targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args);
    if (formatStrByteCount)
      *formatStrByteCount = msgNewline.size_in_bytes();
    return msgValue;
  }

protected:
  // Not owned; must outlive the pattern.
  const CPUTargetInfo &targetInfo;
};

} // namespace

// Registers the tt.print -> LLVM lowering (`PrintOpConversion`) in `patterns`.
// The pattern stores a reference to `targetInfo`, so the caller must keep it
// alive for as long as the patterns are used.
void mlir::triton::cpu::populatePrintOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<PrintOpConversion>(typeConverter, targetInfo, benefit);
}
39 changes: 39 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h"
#include "triton/Conversion/TritonCPUToLLVM/Utility.h"

namespace {

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

// Lowers `triton::GetProgramIdOp` by replacing it with the program id value
// that `CPUTargetInfo::programId` provides for the requested axis.
struct GetProgramIdOpConversion
    : public ConvertOpToLLVMPattern<triton::GetProgramIdOp> {
  explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter,
                                    const CPUTargetInfo &targetInfo,
                                    PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The program ids live in the trailing three arguments of the enclosing
    // LLVM function. Fail the match (rather than relying on a debug-only
    // assertion) when that function is missing or too small.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp || funcOp.getArguments().size() < 3)
      return rewriter.notifyMatchFailure(
          op, "expected an enclosing llvm.func carrying the grid id arguments");

    Value programId = targetInfo.programId(rewriter, op->getLoc(), funcOp,
                                           op.getAxisAsInt());
    rewriter.replaceOp(op, programId);
    return success();
  }

private:
  // Not owned; must outlive the pattern.
  const CPUTargetInfo &targetInfo;
};

} // namespace

// Registers the tt.get_program_id -> LLVM lowering
// (`GetProgramIdOpConversion`) in `patterns`. The pattern stores a reference
// to `targetInfo`, so the caller must keep it alive for as long as the
// patterns are used.
void mlir::triton::cpu::populateSPMDOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const CPUTargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit);
}
5 changes: 5 additions & 0 deletions lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,15 @@ struct ConvertTritonCPUToLLVM
}

RewritePatternSet patterns(context);
mlir::triton::cpu::CPUTargetInfo targetInfo;
int benefit =
mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions;
mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter,
patterns, benefit);
mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
Expand Down
7 changes: 7 additions & 0 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ def run(self, *args, grid, warmup, **kwargs):
sigvals = sig_and_spec[:len(sigkeys)]
signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}

# The CPU launcher will provide the grid ids directly to the kernel.
# Note that this design is interim and subject to change.
if target[0] == 'cpu':
signature["__grid0"] = 'i32'
signature["__grid1"] = 'i32'
signature["__grid2"] = 'i32'

configs = (self._get_config(*bound_vals), )
constants = {
p.name: v
Expand Down