Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ struct ConvertLayoutOpConversion
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp>
};

void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
Expand Down
237 changes: 194 additions & 43 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "Utility.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

namespace {

using namespace mlir;
using namespace mlir::triton;

Expand All @@ -10,6 +12,23 @@ using ::mlir::LLVM::getSRegValue;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;

// Returns the value of the special register that holds the program id along
// `axis` (0, 1 or 2), read through getSRegValue at `loc`.
//
// The compute capability is not easy to obtain here, so the module's numCTAs
// decides the semantics of the program id instead: with numCTAs == 1 the
// "%ctaid" register is read, otherwise "%clusterid".
//
// `moduleOp` must be non-null; `axis` must be in [0, 3).
Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
               ConversionPatternRewriter &rewriter) {
  assert(axis >= 0);
  assert(axis < 3);
  assert(moduleOp);

  const bool singleCTA =
      triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp) == 1;

  // The register name is "<family>.<suffix>", where axes 0, 1, 2 map to the
  // suffixes 'x', 'y', 'z'.
  std::string regName = singleCTA ? "%ctaid." : "%clusterid.";
  regName.push_back('x' + axis);
  return getSRegValue(rewriter, loc, regName);
}

struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -91,6 +110,12 @@ struct BroadcastOpConversion
}
};

// The input print op contains:
// - a "prefix" (string) specified by the user, and
// - one or more "operands" (tensors).
//
// For each operand, we print all of the values contained in this GPU thread,
// one per line, along with the index of the value in its tensor.
struct PrintOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
using ConvertTritonGPUOpToLLVMPattern<
Expand All @@ -100,45 +125,169 @@ struct PrintOpConversion
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
Value prefixStr =
LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix());

auto getPid = [&](int axis) {
return llGetPid(axis, loc, op->getParentOfType<ModuleOp>(), rewriter);
};
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};

// Simple printf of a string without any tensors.
if (op.getNumOperands() == 0) {
  std::string formatStr;
  llvm::raw_string_ostream os(formatStr);
  os << "pid (" << getFormatSubstr(pid[0]) << ", "
     << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
  llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
  // The print op is fully lowered at this point, so it must be erased just
  // like on the tensor path below. Returning success() while the root op is
  // still in the IR leaves an unconverted triton::PrintOp behind.
  rewriter.eraseOp(op);
  return success();
}

for (size_t i = 0; i < op.getNumOperands(); i++) {
auto sub_operands = getTypeConverter()->unpackLLElements(
// Elements of the tensor that are resident in this GPU thread.
auto elems = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
for (auto elem : sub_operands) {
operands.push_back(elem);

// Get the indices of `elems` within the tensor. Note that if `elems` has
// an "interesting" layout, then these will not be in any particularly
// nice order.

// Extract the shape of the tensor being printed and use it to figure out
// how many digits we need for each of the dimensions.
SmallVector<int, 8> dimWidths;
SmallVector<SmallVector<Value>> indices;
if (auto rankedTy =
op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
indices = emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy);
for (int64_t dim : rankedTy.getShape()) {
if (dim > 0) {
dimWidths.push_back(static_cast<int>(std::ceil(std::log10(dim))));
} else {
dimWidths.push_back(0);
}
}
} else {
// We're printing a scalar.
assert(elems.size() == 1);
indices.push_back({});
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.getPrefix();
if (!operands.empty()) {
os << getFormatSubstr(operands[0]);
}

for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
if (!elems.empty()) {
printTensor(prefixStr, /*operand=*/i,
/*numOperands=*/op.getNumOperands(), elems, pid, indices,
dimWidths, rewriter);
}
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}

std::string getFormatSubstr(Value value) const {
// Emits one printf call for every value in `elems`.
//
//   prefixStr    Value passed as the argument for the "%s" in the format.
//   operand      Position of this operand within the print op; only printed
//                when numOperands > 1.
//   numOperands  Total number of operands of the print op.
//   elems        Values to print; must be non-empty.
//   pid          The three program ids, printed first on every line.
//   indices      indices[i] is the multi-dimensional index printed next to
//                elems[i]; must have the same length as `elems`.
//   dimWidths    printf field width used for each index dimension; must have
//                one entry per dimension of an index.
void printTensor(Value prefixStr, size_t operand, size_t numOperands,
                 ArrayRef<Value> elems, std::array<Value, 3> pid,
                 ArrayRef<SmallVector<Value>> indices,
                 ArrayRef<int> dimWidths,
                 ConversionPatternRewriter &rewriter) const {
  assert(!elems.empty());
  assert(elems.size() == indices.size());
  assert(dimWidths.size() == indices.front().size());

  // Format is:
  //   pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...)<prefix> (operand <n>) <elem>
  // where we leave off "(operand <n>)" if there's only one operand.
  //
  // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts
  // with " " and ends with ": ").

  // nvptx printf can only accept 32 args; if we pass more than that, it
  // will print garbage for the trailing args.
  constexpr int kMaxPrintfOperands = 32;

  Value formatStrValue;
  for (size_t i = 0; i < elems.size(); i++) {
    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    SmallVector<Value, kMaxPrintfOperands> printfOperands;

    // TODO(jlebar): We really should pad the pid, but because the max pid is
    // not known at compile-time, this would require nontrivial device-side
    // work.
    os << "pid (";
    for (size_t j = 0; j < pid.size(); j++) {
      if (j != 0) {
        os << ", ";
      }
      os << getFormatSubstr(pid[j]);
      printfOperands.push_back(pid[j]);
    }
    os << ") ";

    // If the rank is large enough, we could end up exceeding
    // kMaxPrintfOperands. In that case, just truncate the index.
    // (Subtract 2 because we're going to add two operands after the index.)
    const size_t maxAllowedRank =
        kMaxPrintfOperands - printfOperands.size() - 2;

    os << "idx (";
    const auto &index = indices[i];
    for (size_t dim = 0; dim < index.size(); dim++) {
      if (dim != 0) {
        os << ", ";
      }
      if (dim == maxAllowedRank) {
        os << "... (truncated)";
        break;
      }
      os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]);
      printfOperands.push_back(index[dim]);
    }
    os << ")";

    os << "%s";
    printfOperands.push_back(prefixStr);

    if (numOperands > 1) {
      os << "(operand " << operand << ") ";
    }

    auto elem = elems[i];
    os << getFormatSubstr(elem);
    printfOperands.push_back(elem);

    // It's the same format string each iteration, but it's a lot easier if we
    // construct the format string at the same time as we populate
    // printfOperands. But we don't want to create BLOCK_SIZE duplicate
    // strings, so we cache the Value returned by the first call and reuse it.
    if (i == 0) {
      formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
    } else {
      llPrintf(formatStrValue, printfOperands, rewriter);
    }
  }
}

std::string getFormatSubstr(Value value,
std::optional<int> width = std::nullopt) const {
std::string prefix = "%";
if (width.has_value()) {
prefix += std::to_string(*width);
}

Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
return prefix + "p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
return prefix + "f";
} else if (type.isSignedInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return "%lli";
return prefix + "lli";
else
return "%i";
return prefix + "i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return "%llu";
return prefix + "llu";
else
return "%u";
return prefix + "u";
}
assert(false && "not supported type");
return "";
Expand Down Expand Up @@ -194,9 +343,22 @@ struct PrintOpConversion
return {newType, newOp};
}

static void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
// Emits a printf of `msg` (with a newline appended) using `args` as the
// printf arguments. Returns the Value of the format string that was added to
// the module, so callers can pass it to the Value overload of llPrintf
// instead of adding the same string again.
static Value llPrintf(StringRef msg, ValueRange args,
                      ConversionPatternRewriter &rewriter) {
  assert(!msg.empty() && "printf with empty string not supported");

  // The stored text is "<msg>\n" followed by an explicit NUL terminator.
  llvm::SmallString<64> formatText(msg);
  formatText.push_back('\n');
  formatText.push_back('\0');

  auto unknownLoc = UnknownLoc::get(rewriter.getContext());
  Value formatValue = LLVM::addStringToModule(unknownLoc, rewriter,
                                              "printfFormat_", formatText);
  llPrintf(formatValue, args, rewriter);
  return formatValue;
}

static void llPrintf(Value msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
Type int8Ptr = ptr_ty(i8_ty);

auto *ctx = rewriter.getContext();
Expand All @@ -208,11 +370,6 @@ struct PrintOpConversion
Value one = i32_val(1);
Value zero = i32_val(0);

llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value prefixString =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
Value bufferPtr = null(int8Ptr);

SmallVector<Value, 16> newArgs;
Expand Down Expand Up @@ -240,7 +397,7 @@ struct PrintOpConversion
bufferPtr = bitcast(allocated, int8Ptr);
}

SmallVector<Value> operands{prefixString, bufferPtr};
SmallVector<Value> operands{msg, bufferPtr};
call(funcOp, operands);
}
};
Expand Down Expand Up @@ -390,20 +547,8 @@ struct GetProgramIdOpConversion
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
// "%clusterid".
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);

Location loc = op->getLoc();
assert(op.getAxisAsInt() < 3);
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'

Value programId = getSRegValue(rewriter, loc, sreg);
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
op->getParentOfType<ModuleOp>(), rewriter);
rewriter.replaceOp(op, programId);
return success();
}
Expand Down Expand Up @@ -685,6 +830,10 @@ struct AsyncBulkCommitGroupOpConversion
}
};

} // namespace

namespace mlir::triton {

void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
Expand All @@ -710,3 +859,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<PrintOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}

} // namespace mlir::triton
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
using namespace mlir;
using namespace mlir::triton;

namespace mlir::triton {

// Adds the TritonGPU-to-LLVM conversion patterns implemented in
// TritonGPUToLLVM.cpp to `patterns`; the patterns are constructed with
// `typeConverter` and registered with the given `benefit`.
// NOTE(review): how numWarps, axisInfoAnalysis, allocation and indexCacheInfo
// are consumed is not visible from this declaration — confirm against the
// definition before relying on them.
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

} // namespace mlir::triton

#endif
8 changes: 4 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ class ConvertTritonGPUOpToLLVMPatternBase {
// Key: {layout, shape, withCTAOffset}
struct IndexCacheInfo {
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
*baseIndexCache;
*baseIndexCache = nullptr;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo> *indexCache;
OpBuilder::InsertPoint *indexInsertPoint;
CacheKeyDenseMapInfo> *indexCache = nullptr;
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
};

explicit ConvertTritonGPUOpToLLVMPatternBase(
Expand Down Expand Up @@ -778,7 +778,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"emitIndices for layouts other than blocked, mma, and slice not "
"implemented yet");
}
if (cache) {
Expand Down
Loading