Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
852f156
Add blocked to dot shortcut
oplavsic Jan 5, 2025
202dd21
pack tensors in vectors instead of structures
binarman Aug 8, 2024
b854eab
fix
binarman Aug 8, 2024
502b26c
add moe bypass option
binarman Aug 8, 2024
21f5acc
initial commit
Aug 7, 2024
246a377
fix
Aug 8, 2024
75a6e1a
fix
Aug 9, 2024
c9d65d5
add missing configurations and add more checks in passes
binarman Aug 9, 2024
8e4bd03
adjust global load layout for vllm swizzling format
binarman Aug 10, 2024
b9f30af
Remove debug print
vgokhale Aug 11, 2024
fe2820b
make load width dependent on data type
binarman Aug 20, 2024
33e9591
fix int 8 logic
binarman Aug 21, 2024
29333de
generalize load analysis: return last load in dependent load chain in…
binarman Aug 21, 2024
3072f2d
Add message for assert failure
zhanglx13 Aug 22, 2024
96b27bc
add k=512/1024 cases
binarman Aug 26, 2024
a06c437
[BACKEND] Add memory space to memdesc type. (#4027)
ThomasRaoux May 29, 2024
59391fd
[BACKEND] Fix memory side effects of `tt.dot` (#4033)
Jokeren May 30, 2024
5c2bd6a
remove streamPipelinev2
oplavsic Jan 5, 2025
3327b4a
[TEST] NFC: Drop irrelevant NVIDIA specific attributes (#4384)
antiagainst Jul 24, 2024
9728b51
[Pipeliner] NFC: Expose Pipeliner infrastructure for use by other tar…
sjw36 Jun 19, 2024
5f4ee26
[BACKEND] Fix regression in pipeliner pre-checks. (#4196)
ThomasRaoux Jun 24, 2024
fe8a8bd
[Backend][AMD] Introduce stream pipeliner v2 (#4148)
sjw36 Jul 29, 2024
50bff3d
[AMD] Prefetch loads and independent local_stores (#4429)
sjw36 Aug 1, 2024
a6d115a
[Pipeliner] Implement dynamic loop peeling
Aug 1, 2024
e4d0a43
* disabled for num_stages > 2
Aug 1, 2024
7265557
* guard each stage of ramp-down in epilogue
Aug 7, 2024
6cb24f1
* pipeline reg buffers
Aug 8, 2024
1636444
[AMD] Fixed bug with tritonamdgpu-reorder-instructions
Aug 9, 2024
62f4fbc
* only move ops early
Aug 10, 2024
07d0178
Fix in streamPipelinerV2
oplavsic Jan 5, 2025
2639388
Fix lit tests
oplavsic Jan 6, 2025
b74c8cf
[Backend][AMD] Add temporary environment variable for pipeliner v2 (#…
antiagainst Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUBypassLDSForDotLayout();
mlir::registerTritonAMDGPUStreamPipeline();
mlir::registerTritonAMDGPUStreamPipelineV2();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ bool isSingleValue(Value value);

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
Expand Down
35 changes: 35 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,20 @@ inline SmallVector<Value>
unpackLLElements(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");

if (isMoeLDSBypass()) {
auto llvmVec = llvmStruct;
auto vecTy = dyn_cast<::mlir::VectorType>(llvmVec.getType());
if (vecTy) {
auto elemTy = vecTy.getElementType();
auto elemNo = vecTy.getDimSize(0);
SmallVector<Value> results;
for (int elem = 0; elem < elemNo; ++elem)
results.push_back(extract_element(llvmVec, i32_val(elem)));
return results;
}
}

if (llvmStruct.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmStruct.getType()) ||
isa<LLVM::LLVMPointerType>(llvmStruct.getType()))
Expand All @@ -1550,6 +1564,26 @@ inline Value packLLElements(Location loc,
const LLVMTypeConverter *typeConverter,
ValueRange resultVals,
ConversionPatternRewriter &rewriter, Type type) {
if (isMoeLDSBypass()) {
auto firstType = resultVals[0].getType();
bool sameTypes = true;
for (int i = 1; i < resultVals.size(); ++i)
if (resultVals[i].getType() != firstType) {
sameTypes = false;
break;
}
if (sameTypes && firstType.isIntOrFloat() && resultVals.size() > 1) {
// Packing into vector instead of structure, to prevent LLVM splitting
// structure into separate values
auto vecTy = vec_ty(firstType, resultVals.size());
Value llvmVector = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (const auto &v : llvm::enumerate(resultVals))
llvmVector =
insert_element(vecTy, llvmVector, v.value(), i32_val(v.index()));
return llvmVector;
}
}

auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (!structType) {
Expand All @@ -1562,6 +1596,7 @@ inline Value packLLElements(Location loc,
emitError(loc) << " size mismatch when packing elements for LLVM struct"
<< " expected " << elementTypes.size() << " but got "
<< resultVals.size();
assert(false);
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
Expand Down
4 changes: 4 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
namespace mlir {
namespace triton {

void enableMoeLDSBypass(bool value);

bool isMoeLDSBypass();

struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
StringRef getName() final { return "<GlobalMemory>"; }
};
Expand Down
48 changes: 46 additions & 2 deletions include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"

#include <iostream>
#include "triton/Dialect/Triton/IR/Types.h"

namespace mlir {
namespace OpTrait {
Expand Down Expand Up @@ -58,6 +57,51 @@ class VerifyTensorLayoutsTrait
}
};

// Verify if the op is a dot-like operation.
// A dot-like operation should have three operands: A, B, and the accumulator C.
// A and B must share a common (contraction) dimension: A's last dimension must
// equal B's second-to-last dimension. The result/accumulator shape is the
// concatenation of A's and B's non-shared dimensions.
// A dot-like operation can be either 2d or 3d; in the 3d case the first
// dimension of every operand is the batch dimension.
template <class ConcreteType>
class DotLike : public TraitBase<ConcreteType, DotLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    if (op->getNumOperands() != 3)
      return op->emitOpError("expected 3 operands");
    // A and B may live in registers or in shared memory (memdesc); the
    // accumulator is always a tensor.
    auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
    auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
    auto cTy = cast<TensorType>(op->getOperand(2).getType());
    auto aShape = aTy.getShape();
    auto bShape = bTy.getShape();
    auto cShape = cTy.getShape();
    // Check if all 3d or all 2d.
    if (aShape.size() != 2 && aShape.size() != 3)
      return op->emitOpError("expected operands to be 2d or 3d");
    if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
      return op->emitOpError("expected all operands to have the same rank");
    // Check that A and B share the contraction dimension: A[-1] == B[-2].
    if (aShape[aShape.size() - 1] != bShape[bShape.size() - 2])
      return op->emitOpError("expected the last dimension of the first operand "
                             "to be equal to the second-to-last dimension of "
                             "the second operand");
    // Check the batch dimension: all three operands must agree on it.
    if (aShape.size() == 3 &&
        (aShape[0] != cShape[0] || bShape[0] != cShape[0]))
      return op->emitOpError("expected the first dimension of the operands "
                             "to be equal to the first dimension of the "
                             "result");
    // Check the output shape: C[-2] == A[-2] and C[-1] == B[-1].
    if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
        cShape[cShape.size() - 1] != bShape[bShape.size() - 1])
      return op->emitOpError(
          "expected the output shape to be the concatenation of the "
          "second-to-last dimension of the first operand and the last "
          "dimension of the second operand");
    return success();
  }
};

template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
Expand Down
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include "mlir/IR/OpBase.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
def DotLike : NativeOpTrait<"DotLike">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
Expand Down
5 changes: 3 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
//
def TT_DotOp : TT_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
Expand All @@ -640,8 +641,8 @@ def TT_DotOp : TT_Op<"dot", [Pure,

let arguments = (
ins
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc
Expand Down
14 changes: 9 additions & 5 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def TT_BoolTensor : RankedTensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;

// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def I4 : I<4>;
def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
def TT_IntTensor : RankedTensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;

Expand Down Expand Up @@ -106,12 +107,13 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutable_memory
);
let extraClassDeclaration = [{
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding());
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
}

bool hasRank() const { return true; }
Expand All @@ -120,17 +122,19 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding
"Attribute":$encoding,
"Attribute":$memorySpace
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false);
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutableMemory
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory);
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
}]>
];
let hasCustomAssemblyFormat = 1;
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1298,4 +1298,10 @@ elements along the K dim, or they use all elements of the tensor along the K dim
}];
}

// Memory-space attribute for `memdesc` types: marks a memory descriptor as
// residing in shared (LDS/SMEM) memory. Rendered as `#triton_gpu.shared_memory`.
def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {
  let mnemonic = "shared_memory";
  let description = [{
    Attribute to indicate that the memory descriptor points to shared memory.
  }];
}
#endif
8 changes: 7 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
}];
let arguments = (ins Optional<TT_Tensor>:$src);

let extraClassDeclaration = [{
bool isSharedMemoryAlloc() {
return getType().getMemorySpace() &&
isa<SharedMemorySpaceAttr>(getType().getMemorySpace());
}
}];
let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];

let results = (outs TT_MemDescType:$result);
Expand All @@ -163,7 +169,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM

Because we assume a memdesc is dead at the first point that post-dominates
its uses, ops that wait for an async operation on a memdesc to complete
(such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an
(such as triton_nvidia_gpu.warp_group_dot_wait) should also take the memdesc as an
operand.
}];

Expand Down
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
"mlir::arith::ArithDialect"];
}


def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";

Expand Down
107 changes: 107 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
#include "llvm/ADT/ArrayRef.h"
#include <list>
#include <vector>

namespace mlir {
namespace triton {

/// This fill out the pipelining options including schedule and annotations
/// for wait ops. This also does pre-processing by converting some of the
/// loads into async loads so that the IR is ready to be pipelined.
bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
mlir::triton::PipeliningOption &options);

/// Fills out pipelining options for an outer loop pipelining case. This
/// schedules async copies to overlap with the epilogue of a loop.
bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages,
mlir::triton::PipeliningOption &options);

/// Pipeline the TMA stores in the loop.
bool pipelineTMAStores(scf::ForOp forOp);

/// This does post-processing on the pipelined loop to try to pipeline wgmma
/// ops.
// TODO: this should be included as part of the pipeline but currently the wgmma
// wait modeling is problematic.
void asyncLaunchDots(scf::ForOp forOp);

/// Post process the pipelined loop by updating the wait ops with the right
/// number of groups in flight.
void updateWaits(ModuleOp module);

/// A coarse schedule for software pipelining: each op is assigned a pipeline
/// stage (0..numStages-1) and a "cluster" that orders ops within a stage.
/// Clusters are kept in a totally ordered list; the final per-op schedule is
/// derived from (stage, cluster) pairs.
class CoarseSchedule {
public:
  /// Ordered list of cluster ids. Iterators into the underlying std::list are
  /// used as stable cluster handles (list iterators survive insertions), while
  /// the stored int gives the cluster's ordinal position for comparison.
  class ClusterList {
    std::list<int> orderClusters;

  public:
    using iterator = decltype(orderClusters)::iterator;
    ClusterList() = default;
    iterator begin() { return orderClusters.begin(); }
    iterator end() { return orderClusters.end(); }
    size_t size() { return orderClusters.size(); }
    /// Append a new cluster with the next ordinal id and return its handle.
    iterator newAtBack() {
      orderClusters.push_back(orderClusters.size());
      return std::prev(orderClusters.end());
    }
    /// Prepend a new cluster. Existing ids are shifted up by one so ordinals
    /// stay dense and consistent with list order (new cluster gets id 0).
    iterator newAtFront() {
      orderClusters.push_front(-1);
      for (auto &clusterId : orderClusters) {
        clusterId++;
      }
      return orderClusters.begin();
    }
    /// Insert a new cluster immediately before `cluster`. The new node copies
    /// `cluster`'s id, then every id from `cluster` to the end is bumped so
    /// ordinals remain unique and ordered.
    iterator newBefore(iterator cluster) {
      auto ret = orderClusters.insert(cluster, *cluster);
      for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) {
        clusterId++;
      }
      return ret;
    }
  };

  CoarseSchedule(int numStages) : numStages(numStages) {}
  // Number of pipeline stages this schedule targets.
  int numStages;
  ClusterList clusters;
  using Cluster = decltype(clusters)::iterator;

  // Per-op assignment of (stage, cluster). Ops not present are unscheduled.
  DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;

  /// Assign (or overwrite) `op`'s stage and cluster.
  void insert(Operation *op, int stage, Cluster cluster) {
    opToStageAndCluster[op] = {stage, cluster};
  }

  /// Assign `op` only if it has no assignment yet; returns true if inserted.
  bool insertIfAbsent(Operation *op, int stage, Cluster cluster) {
    if (opToStageAndCluster.count(op))
      return false;
    insert(op, stage, cluster);
    return true;
  }

  /// Recursively schedule the (unscheduled) dependencies of `op` at the given
  /// stage/cluster; `includeArg` controls whether loop-carried block arguments
  /// are followed. Defined out of line.
  void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
                      bool includeArg);

  /// Remove `op`'s assignment, if any.
  void erase(Operation *op) { opToStageAndCluster.erase(op); }

  /// Returns 1 if `op` is scheduled, 0 otherwise (DenseMap::count semantics).
  int count(Operation *op) { return opToStageAndCluster.count(op); }

  /// Look up `op`'s (stage, cluster). NOTE: operator[] default-constructs an
  /// entry for an unscheduled op — callers should check count() first.
  std::pair<int, Cluster> operator[](Operation *op) {
    return opToStageAndCluster[op];
  }

  /// Ops of `forOp` in scheduled order, with their stage and cluster.
  SmallVector<std::tuple<Operation *, int, Cluster>>
  getOpsInOrder(scf::ForOp forOp);
  /// Flatten the coarse schedule into the (op, stage) list consumed by the
  /// pipeline expander.
  std::vector<std::pair<Operation *, unsigned>>
  createFinalSchedule(scf::ForOp forOp);
  /// Print the schedule for debugging.
  void dump();
};

} // namespace triton
} // namespace mlir
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
5 changes: 4 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SharedEncodingAttr;

SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
TensorOrMemDesc type,
RankedTensorType type,
int numWarps);

/// Returns true if the Load uses block pointer.
Expand Down Expand Up @@ -135,6 +135,9 @@ scf::IfOp replaceIfOpWithNewSignature(
RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes,
SmallVectorImpl<std::tuple<Value, Value>> &replacements);

// Append the given |newOperands| to the |forOp|'s yield op.
void appendToForOpYield(scf::ForOp forOp, ArrayRef<Value> newOperands);

Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);

Expand Down
Loading