Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::registerTritonAMDGPUConvertToBufferOps();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Linear Algebra Ops
:nosignatures:

dot
dot_scaled


Memory/Pointer Ops
Expand Down
4 changes: 1 addition & 3 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,6 @@ namespace mlir::triton {

namespace gpu {

SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -187,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -209,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
it += curr.size();
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
* A backend-specific callback for appending auxiliary data during
* `LocalStoreOp` conversion.
*
* @param[in] op The reference to the re-written `LocalStoreOp`.
* @param[in] count The number of issued LLVM instructions.
* @param[in] type The input type of issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type llvmOpType)>
localStoreOpConversion = nullptr;
};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);
// The given callback is invoked at the end of a successful rewrite. The
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions and 3) their input types. Each MLIR backend can provide a
// callback and, thus, handle backend-specific behaviors.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
Expand Down
30 changes: 22 additions & 8 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
return base;
}

// -----------------------------------------------------------------------
// MXFP utilities
// -----------------------------------------------------------------------

// Convert each value, which is an int8 containing 2 packed mxfp4 values,
// into 2 standalone bf16 values
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values);

// Scale a mxfp4 value by a given scale.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);

} // namespace LLVM

/* ------------------------------------ */
Expand Down Expand Up @@ -453,15 +466,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto threadOrder = blockedLayout.getThreadOrder();
auto warpOrder = blockedLayout.getWarpOrder();
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();

// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
Expand Down Expand Up @@ -1366,11 +1380,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
Expand Down
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// Type for F8F6F4 kind of floats.
def TT_F8F6F4TypeAttr : I32EnumAttr<
"F8F6F4Type", "",
// Type for ScaleDotElemType kind of floats.
def TT_ScaleDotElemTypeAttr : I32EnumAttr<
"ScaleDotElemType", "",
[
I32EnumAttrCase<"E4M3", 0, "e4m3">,
I32EnumAttrCase<"E5M2", 1, "e5m2">,
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">
I32EnumAttrCase<"E2M1", 4, "e2m1">,
I32EnumAttrCase<"BF16", 5, "bf16">

]>{
let cppNamespace = "::mlir::triton";
Expand Down
16 changes: 8 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,

let arguments = (
ins
// inputs are integer types as they are packed types and we currently
// don't have a representation for those.
TT_IntTensor:$lhs,
TT_IntTensor:$rhs,
// inputs are floats if we have a type for them, otherwise (fp4),
// they are packed in pairs in an I8Tensor
RankedTensorOf<[TT_Float,I8]>:$lhs,
RankedTensorOf<[TT_Float,I8]>:$rhs,
TT_FloatTensor:$c,
TT_IntTensor:$lhs_scale,
Optional<TT_IntTensor>:$rhs_scale,
TT_F8F6F4TypeAttr:$lhs_type,
TT_F8F6F4TypeAttr:$rhs_type
RankedTensorOf<[I8]>:$lhs_scale,
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
TT_ScaleDotElemTypeAttr:$lhs_type,
TT_ScaleDotElemTypeAttr:$rhs_type
);

let results = (outs TT_FloatTensor:$d);
Expand Down
16 changes: 13 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts,
// though the elements are in registers, the order refers to memory
// layout of the original tensor in global memory.
// major (slow-varying). For distributed layouts, this represents
// the order of the elements within a thread.
// For shared Layout, the order refers to which dimension of the original tensor
// is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);
Expand Down Expand Up @@ -130,6 +129,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);

unsigned getNumCTAs(Attribute layout);

// Return the order that represents that the batch is in row-major or
// column-major order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order that represents that the dot operand is in kMajor
// (contiguous in the inner dimension) or it's contiguous on the outer
// dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
Expand Down
59 changes: 41 additions & 18 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
// ---- begin Ampere & Hopper ----
if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
Expand Down Expand Up @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
llvm_unreachable("invalid operand index");
}

// ---- begin version 3 ----
if (mmaEnc.isHopper()) {
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
" is Hopper has not been implemented yet");
return $_get(context, 1, 1, 1, order, CTALayout, true);
}

// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
Expand Down Expand Up @@ -481,9 +474,16 @@ layout = [0 4 8 12]
[3 7 11 15]

For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding.

If the layout does not completely cover the tensor, we tile it until we cover the entire tensor.
We call each individual tile "rep".
}];

let methods = [
InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrder">,

// Interface for the meta information about the multiple thread hierarchy.
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
Expand Down Expand Up @@ -570,6 +570,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];

code extraDistributedDeclaration = extraBaseClassDeclaration # [{
SmallVector<unsigned> getRepOrder() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
Expand Down Expand Up @@ -921,6 +922,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
auto rank = getWarpsPerCTA().size();
Expand Down Expand Up @@ -1029,6 +1031,7 @@ Row | warp 0 warp 2
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
static SmallVector<unsigned> getMNKDimPerInstr();

SmallVector<unsigned> getContigPerThread() {
Expand Down Expand Up @@ -1206,8 +1209,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isAmpere() const;
bool isHopper() const;

unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;

// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;

Expand All @@ -1224,8 +1225,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
Expand Down Expand Up @@ -1319,6 +1321,27 @@ The parent field is the layout of d.
kWidth defines number of consecutive elements stored by one thread along k dimension.
Some layouts do not use this parameter, either because they have a fixed number of
elements along the K dim, or they use all elements of the tensor along the K dim.

# WGMMA Notes
We require kWidth to be provided for Hopper because the dtype at loading might be
different from the dtype at WGMMA, due to casting. The kWidth is determined by the
dtype at WGMMA.

The encoded tensor consists of operand A for possibly multiple wgmma instructions.
For each wgmma, each warp in a warp group feeds a single "warp matrix"
Each warp matrix consists of 2x2 "quads".
Each thread holds several elements in each quad. Right before a wgmma,
the sum of bitwidth of
the elements in each quad should add up to 32.

These values are stored unrolled in `elements`.
The ordering of dimensions is as follows by convention:
batch (only 1 batch for Hopper currently)
matM (m-index of the "warp matrix")
matK (k-index of the "warp matrix")
quadK (k-index of the "quad" in the core matrix)
quadM (m-index of the "quad" in the core matrix)
vecIdx (index of the element in the quad; this is always along the k-dim)
}];

let parameters = (
Expand All @@ -1329,16 +1352,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
);

let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
return $_get(context, opIdx, parent, 0); // For MMAV1
// For MMAV2 and V3
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
unsigned kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, kWidth);
}]>
];

Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
let arguments = (ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_F8F6F4TypeAttr:$fp_type);
TT_ScaleDotElemTypeAttr:$fp_type);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
Expand Down
Loading