974 changes: 974 additions & 0 deletions BUILD

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
4017f04e310454ccced4c404a23f7698eec735ca
6f44bb7717897191be25aa01161831c67cdf5b84
@@ -68,6 +68,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}

// True if elements allocated to a thread are contiguous within the axis. This
// is not the case in MMA-like encodings, where a thread might have elements
// (0,0),(0,1) and (8,0),(8,1), for example. This matters because the
// deduplication mechanism assumes that if, for example, constancy=4 and
// elements/thread=4, then all of a thread's elements are constant.
bool contiguouslyMapped(Attribute encoding) const {
if (auto slice = encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return contiguouslyMapped(slice.getParent());
}
return encoding.isa<triton::gpu::BlockedEncodingAttr>();
}
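To make the dedup assumption concrete, a minimal sketch (hypothetical helper, not part of this patch): under a contiguously mapped layout, element i of a thread can simply reuse the value computed for element i - (i % constancy). With an MMA-like mapping such as (0,0),(0,1),(8,0),(8,1), that aliasing would merge values that are not actually equal.

// Hypothetical sketch of the deduplication assumption (illustrative only);
// safe only when a thread's elements are contiguous along the constant axis.
SmallVector<Value> dedupContiguous(ArrayRef<Value> resultVals,
                                   unsigned constancy) {
  SmallVector<Value> out(resultVals.size());
  for (unsigned i = 0; i < resultVals.size(); ++i)
    out[i] = resultVals[i - (i % constancy)]; // alias within each constant run
  return out;
}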

// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
@@ -93,8 +105,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
if (!encoding)
// encoding not available
return resultVals;
if (!encoding.dyn_cast<BlockedEncodingAttr>() &&
!encoding.dyn_cast<SliceEncodingAttr>()) {
if (!contiguouslyMapped(encoding)) {
// TODO: constraining the encoding type here is necessary to avoid
// crashes in the getElemsPerThread call below, triggered by
// test_core::test_fp8_dot_acc
6 changes: 5 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1250,6 +1252,8 @@ loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto srcElemTy = srcTy.getElementType();
auto dstElemTy = dstTy.getElementType();
LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy
<< " dstElemTy " << dstElemTy);
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
unsigned outVec = inOrd == outOrd
@@ -1281,7 +1283,7 @@ loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
auto valVec = load(wordTy, smemAddr);
valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8);
for (unsigned v = 0; v < minVec; ++v) {
Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
Value currVal = extract_element(elemTy, valVec, i32_val(v));
outVals[i * minVec + v] = currVal;
}
}
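The extract_element fix above matters whenever elemTy and dstElemTy differ: valVec is loaded as a vector of elemTy, so each lane must be extracted as elemTy; extracting as dstElemTy produced a mis-typed value, presumably the mismatch the new LDBG lines in packLLElements below help diagnose.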
@@ -1407,6 +1409,8 @@ static Value packLLElements(Location loc,
<< v.value();
}
if (v.value().getType() != elementTypes[v.index()]) {
LDBG("type " << type << " structType " << structType);
LDBG("value " << v.value());
emitError(loc) << "invalid element type in packLLElements. Expected "
<< elementTypes[v.index()] << " but got "
<< v.value().getType();
21 changes: 21 additions & 0 deletions include/triton/Dialect/Triton/IR/Utility.h
@@ -184,6 +184,27 @@ template <typename VecT> bool isConsecutive(const VecT &vec) {
return isConsecutive(ArrayRef(vec));
}

// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but
// it's missing min/max_element until
// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton.
// TODO(jlebar): Remove this once we have the LLVM helpers.
template <typename R> auto min_element(R &&Range) {
return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range));
}
template <typename R, typename Compare>
auto min_element(R &&Range, Compare &&C) {
return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range),
std::forward<Compare>(C));
}
template <typename R> auto max_element(R &&Range) {
return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range));
}
template <typename R, typename Compare>
auto max_element(R &&Range, Compare &&C) {
return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range),
std::forward<Compare>(C));
}

} // namespace triton
} // namespace mlir
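The shims above are call-compatible with the eventual LLVM helpers; an illustrative use (values hypothetical):

SmallVector<int64_t> shape = {4, 16, 8};
auto *biggest = mlir::triton::max_element(shape);                 // -> 16
auto *smallest = mlir::triton::min_element(shape, std::less<>()); // -> 4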

3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -527,7 +527,8 @@ bool supportMMA(triton::DotOp op, int version) {
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
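Note on the flip above: MMA v3 is now opt-in rather than opt-out. Assuming triton::tools::getBoolEnv follows its usual truthy-string convention, setting ENABLE_MMA_V3 (e.g. to 1) in the environment restores the old behavior; leaving it unset makes supportMMA return false for version 3.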
74 changes: 74 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -301,6 +301,26 @@ struct ExternElementwiseOpConversion
}
};

template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
: public ElementwiseOpConversionBase<
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base =
ElementwiseOpConversionBase<SourceOp,
ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;

// An overridable hook so specializations can customize how DestOp is built.
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
adaptor.getAttributes().getValue())};
}
};
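This template is registered through the POPULATE_* macros added further down in this file; expanded, a single registration reads:

patterns.add<ElementwiseOpConversion<arith::AddIOp, LLVM::AddOp>>(
    typeConverter, axisInfoAnalysis, benefit);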

struct ElementwiseInlineAsmOpConversion
: public ConvertOpToLLVMPattern<ElementwiseInlineAsmOp> {
using Base = ConvertOpToLLVMPattern<ElementwiseInlineAsmOp>;
@@ -720,6 +740,60 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
void mlir::triton::populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);

POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(math::FloorOp, math::FloorOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::Log2Op, math::Log2Op)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op)
POPULATE_UNARY_OP(math::ErfOp, math::ErfOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);

POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> (arithmetic)
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> (logical)
// fmin (return non-NaN if either op is non-NaN)
POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp)
// fmax (return non-NaN if either op is non-NaN)
POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp)
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP

patterns.add<ElementwiseOpConversion<math::FmaOp, LLVM::FMAOp>>(
typeConverter, axisInfoAnalysis, benefit);

patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<CmpIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -23,7 +23,7 @@ using ttg::SliceEncodingAttr;
// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
int baseVersion = 0;
if (computeCapability < 75) {
if (computeCapability < 80) {
baseVersion = 1;
} else if (computeCapability < 90) {
baseVersion = 2;
@@ -307,8 +307,10 @@ class BlockedToMMA : public mlir::RewritePattern {
} else {

// convert operands
int minBitwidth =
std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
// TODO(b/296812125): Fix minBitwidth issue upstream and uncomment.
// int minBitwidth =
// std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
int minBitwidth = 0;
Type minType = IntegerType::get(ctx, minBitwidth);
// convert A operand
auto newAEncoding = ttg::DotOperandEncodingAttr::get(
16 changes: 15 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -7,7 +7,19 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <memory>
#include <string>

inline bool isPipeliningEnabled() {
const char *s = std::getenv("ENABLE_PIPELINING");
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}
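Given the lowercasing above, the helper accepts "on", "true", and "1" in any case (ON, True, ...) and treats everything else, including an unset variable, as disabled; it is used below to gate HoistLayoutConversion.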

namespace {

@@ -329,7 +341,9 @@ class TritonGPUOptimizeDotOperandsPass

mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
// TODO(b/291216607): Fix crashes and enable by default.
if (isPipeliningEnabled() &&
triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);