Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// TritonAMDGPUTransforms passes
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUBypassLDSForDotOperand();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUBlockPingpong();
mlir::registerTritonAMDGPUStreamPipeline();
Expand Down
2 changes: 0 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,6 @@ enum class MMALoadType {
};
MMALoadType getMMALoadType(Operation *loadOp);

// Convert \param op operands and results to layout \param encoding.
void convertOpEncoding(Attribute encoding, Operation *op);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
1 change: 0 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_ENABLE_LLVM_DEBUG",
"TRITON_HIP_STREAM_PREFETCH",
"TRITON_HIP_USE_BLOCK_PINGPONG",
"TRITON_HIP_BYPASS_LDS_FOR_DOT",
"TRITON_LLVM_DEBUG_ONLY",
"TRITON_ENABLE_ASAN",
"TRITON_OVERRIDE_ARCH",
Expand Down
51 changes: 50 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,55 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
threadsPerWarp, CTALayout);
}

// Clone a ranked tensor type, keeping its shape and element type but
// substituting the given layout encoding.
static Type getNewType(Type type, Attribute encoding) {
  auto tensorTy = cast<RankedTensorType>(type);
  return RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType(),
                               encoding);
}

// Rewrite \p op so that its tensor operands and results carry \p encoding.
// Each tensor operand is converted into the new layout, the op is recreated
// with retyped results, and every changed result is converted back to its
// original layout before replacing all uses of the old op, which is erased.
void coalesceOp(Attribute encoding, Operation *op) {
  OpBuilder builder(op);
  // Convert operands
  // For load/store with tensor pointers, we don't have to change the
  // operands' type, we do this by changing the outputs' type of
  // `make_tensor_ptr`
  SmallVector<Value, 4> newArgs;
  for (auto operand : op->getOperands()) {
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    // Shared-encoding operands are left untouched; only register-layout
    // tensors are wrapped in a convert_layout to the new encoding.
    if (tensorType &&
        !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
      Type newType = getNewType(tensorType, encoding);
      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, operand));
    } else {
      newArgs.push_back(operand);
    }
  }

  // Convert output types
  SmallVector<Type, 4> newTypes;
  for (auto t : op->getResultTypes()) {
    // AsyncCopyGlobalToLocalOp results are kept as-is — presumably they are
    // not register tensors that can take the new encoding (verify).
    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
  }

  // Construct new op with the new encoding
  Operation *newOp =
      builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
                     newTypes, op->getAttrs());

  // Cast the results back to the original layout
  for (size_t i = 0; i < op->getNumResults(); i++) {
    Value newResult = newOp->getResult(i);
    // Only insert a convert_layout where the type actually changed.
    if (newTypes[i] != op->getResultTypes()[i]) {
      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), op->getResult(i).getType(), newResult);
    }
    op->getResult(i).replaceAllUsesWith(newResult);
  }
  // All uses have been rewired above, so the original op is now dead.
  op->erase();
}

void runOnOperation() override {
// Run axis info analysis
ModuleOp moduleOp = getOperation();
Expand Down Expand Up @@ -138,7 +187,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
for (auto &kv : layoutMap) {
convertOpEncoding(kv.second, kv.first);
coalesceOp(kv.second, kv.first);
}
}
};
Expand Down
48 changes: 6 additions & 42 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,43 +1022,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
}
}

// Decide whether a layout conversion may be propagated backward.
// Conversions to non-dot-operand layouts are always allowed; conversions to
// dot-operand layouts are restricted by the heuristics below.
bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
  auto dotEnc =
      dyn_cast<DotOperandEncodingAttr>(convertOp.getType().getEncoding());
  // Any target layout other than DotOperandEncodingAttr: propagate freely.
  if (!dotEnc)
    return false == false;

  // Operand index 0: never propagate. This heuristic prevents moving the
  // blocked->dot conversion of the Q tensor (a loop invariant in Flash
  // Attention) outside the loop, which can increase register pressure and
  // cause spilling in some cases.
  if (dotEnc.getOpIdx() == 0)
    return false;

  // Operand index 1: propagate only when the conversion is not fed by a load
  // in the same region, since placing it above a load needs more careful
  // heuristics for both correctness and performance.
  // TODO: Fix this logic to avoid propagating conversions backward unless
  // it reduces the total number of conversions.
  assert(dotEnc.getOpIdx() == 1);
  BackwardSliceOptions sliceOpts;
  sliceOpts.omitBlockArguments = true;
  sliceOpts.filter = [&](Operation *candidate) {
    return candidate->getParentRegion() == convertOp->getParentRegion();
  };
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(convertOp.getOperation(), &backwardSlice, sliceOpts);

  bool feedsFromLoad = false;
  for (Operation *sliceOp : backwardSlice) {
    if (isa<LoadOp>(sliceOp)) {
      feedsFromLoad = true;
      break;
    }
  }
  // Allow propagation only if no LoadOp was found in the slice.
  return !feedsFromLoad;
}

void LayoutRematerialization::hoistConvertIntoConditionals() {
// Go through each ConvertLayoutOp.
SmallVector<ConvertLayoutOp> convertOps;
Expand All @@ -1077,11 +1040,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {

void LayoutRematerialization::backwardRematerialization(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
RankedTensorType targetType = convertOp.getType();
if (!shouldPropagateConversion(convertOp)) {
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
return;
}

Value oldV = convertOp.getSrc();
LDBG("check backward remat with source " << oldV << " encoding "
<< targetType.getEncoding());
Expand Down Expand Up @@ -1120,10 +1083,11 @@ void LayoutRematerialization::backwardRematerialization(
// of the convert.
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention
RankedTensorType targetType = convertOp.getType();
if (!shouldPropagateConversion(convertOp)) {
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
return;
}

auto isExtOrBroadcastOp = [](Operation *op) {
if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,
Expand Down
48 changes: 0 additions & 48 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,54 +1057,6 @@ MMALoadType getMMALoadType(Operation *loadOp) {
}
}

// Produce the same ranked tensor type with its layout replaced by `encoding`;
// shape and element type are preserved.
static Type getNewType(Type type, Attribute encoding) {
  auto rankedTy = cast<RankedTensorType>(type);
  return RankedTensorType::get(rankedTy.getShape(), rankedTy.getElementType(),
                               encoding);
}

// Convert \p op operands and results to layout \p encoding: operands gain a
// convert_layout into the new encoding, the op is rebuilt with retyped
// results, and changed results are converted back to their original layouts
// before all uses are replaced and the old op erased.
void convertOpEncoding(Attribute encoding, Operation *op) {
  OpBuilder builder(op);
  // Convert operands
  // For load/store with tensor pointers, we don't have to change the
  // operands' type, we do this by changing the outputs' type of
  // `make_tensor_ptr`
  SmallVector<Value, 4> newArgs;
  for (auto operand : op->getOperands()) {
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    // Shared-encoding operands are left as-is; only register-layout tensors
    // are rewrapped into the requested encoding.
    if (tensorType &&
        !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
      Type newType = getNewType(tensorType, encoding);
      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, operand));
    } else {
      newArgs.push_back(operand);
    }
  }

  // Convert output types
  SmallVector<Type, 4> newTypes;
  for (auto t : op->getResultTypes()) {
    // AsyncCopyGlobalToLocalOp result types are kept unchanged — presumably
    // they are not register tensors that can take the encoding (verify).
    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
  }

  // Construct new op with the new encoding
  Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
                                    newArgs, newTypes, op->getAttrs());

  // Cast the results back to the original layout
  for (size_t i = 0; i < op->getNumResults(); i++) {
    Value newResult = newOp->getResult(i);
    // Insert a convert_layout only when the result type actually changed.
    if (newTypes[i] != op->getResultTypes()[i]) {
      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), op->getResult(i).getType(), newResult);
    }
    op->getResult(i).replaceAllUsesWith(newResult);
  }
  // Every use has been rewired, so the original op is dead and can be erased.
  op->erase();
}

namespace {

/// Detect dead arguments in scf.for op by assuming all the values are dead and
Expand Down
Loading