Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// TritonAMDGPUTransforms passes
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUBypassLDSForDotOperand();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ bool isPureUnaryInlineAsm(Operation *op);
// read the compute capability from the module attributes
int getNVIDIAComputeCapability(Operation *module);

// Convert \param op operands and results to layout \param encoding.
void convertOpEncoding(Attribute encoding, Operation *op);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved (and renamed) this function from Coalece.cpp so I could use it in BypassLDS pass since I needed this exact functionality.

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
51 changes: 1 addition & 50 deletions lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,55 +104,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
threadsPerWarp, CTALayout);
}

static Type getNewType(Type type, Attribute encoding) {
RankedTensorType tensorType = cast<RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
}

void coalesceOp(Attribute encoding, Operation *op) {
OpBuilder builder(op);
// Convert operands
// For load/store with tensor pointers, we don't have to change the
// operands' type, we do this by changing the outputs' type of
// `make_tensor_ptr`
SmallVector<Value, 4> newArgs;
for (auto operand : op->getOperands()) {
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
if (tensorType &&
!isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
Type newType = getNewType(tensorType, encoding);
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, operand));
} else {
newArgs.push_back(operand);
}
}

// Convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) {
bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
newTypes.push_back(isAsync ? t : getNewType(t, encoding));
}

// Construct new op with the new encoding
Operation *newOp =
builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
newTypes, op->getAttrs());

// Cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}

void runOnOperation() override {
// Run axis info analysis
ModuleOp moduleOp = getOperation();
Expand Down Expand Up @@ -187,7 +138,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
for (auto &kv : layoutMap) {
coalesceOp(kv.second, kv.first);
convertOpEncoding(kv.second, kv.first);
}
}
};
Expand Down
22 changes: 16 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Comment thread
antiagainst marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,15 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {

void LayoutRematerialization::backwardRematerialization(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
// Skip conversions to DotOperandEncodingAttr when the operand index is 0.
// This heuristic is applied to prevent moving the blocked->dot conversion of
// the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
// so can increase register pressure and cause spilling in some cases.
// TODO: Fix this logic to avoid propagating conversions backward unless
// it reduces the total number of conversions.
RankedTensorType targetType = convertOp.getType();
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
Comment thread
antiagainst marked this conversation as resolved.
if (dotEnc && dotEnc.getOpIdx() == 0)
return;
Value oldV = convertOp->getOperand(0);
LDBG("check backward remat with source " << oldV << " encoding "
Expand Down Expand Up @@ -1010,10 +1015,15 @@ void LayoutRematerialization::backwardRematerialization(
// of the convert.
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention
// Skip conversions to DotOperandEncodingAttr when the operand index is 0.
// This heuristic is applied to prevent moving the blocked->dot conversion of
// the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
// so can increase register pressure and cause spilling in some cases.
// TODO: Fix this logic to avoid propagating conversions backward unless
// it reduces the total number of conversions.
RankedTensorType targetType = convertOp.getType();
if (mlir::isa<DotOperandEncodingAttr>(targetType.getEncoding()))
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
if (dotEnc && dotEnc.getOpIdx() == 0)
return;

auto isExtOrBroadcastOp = [](Operation *op) {
Expand Down
48 changes: 48 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,54 @@ int getNVIDIAComputeCapability(Operation *module) {
return computeCapability;
}

static Type getNewType(Type type, Attribute encoding) {
RankedTensorType tensorType = cast<RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
}

void convertOpEncoding(Attribute encoding, Operation *op) {
OpBuilder builder(op);
// Convert operands
// For load/store with tensor pointers, we don't have to change the
// operands' type, we do this by changing the outputs' type of
// `make_tensor_ptr`
SmallVector<Value, 4> newArgs;
for (auto operand : op->getOperands()) {
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
if (tensorType &&
!isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
Type newType = getNewType(tensorType, encoding);
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, operand));
} else {
newArgs.push_back(operand);
}
}

// Convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) {
bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
newTypes.push_back(isAsync ? t : getNewType(t, encoding));
}

// Construct new op with the new encoding
Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
newArgs, newTypes, op->getAttrs());

// Cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}

namespace {

/// Detect dead arguments in scf.for op by assuming all the values are dead and
Expand Down
Loading