Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,28 @@ class GraphLayoutMarker : public GraphDumper {
std::string getColor(const Type &type) const;
};

// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
// Infers the encoding of the result of op given the source encoding.
std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding);

bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
// Infers the encoding of the source of op given the result encoding.
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
bool isExpensiveLoadOrStore(Operation *op);

// skipInit is True when we only consider the operands of the initOp but
// not the initOp itself.
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);
bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);

Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);

void rematerializeConversionChain(
const llvm::MapVector<Value, Attribute> &toConvert,
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
IRMapping &mapping);
// Get backward slice of tensor values starting from the root node along with
// encoding propagation.
LogicalResult getConvertBackwardSlice(
Value root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);

LogicalResult canMoveOutOfLoop(BlockArgument arg,
SmallVector<Operation *> &cvts);
// Populate pattern to remove dead cycles in ForOp.
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
Expand Down
16 changes: 13 additions & 3 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,11 @@ struct DFSState {
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;

/// We mark each op as ready if all its operands are seen. If an op is ready,
/// we add it to the queue. Otherwise, we keep adding its operands to the
/// ancestors set.
/// We mark each op as ready if all its operands and parent ops are seen. If
/// an op is ready, we add it to the queue. Otherwise, we keep adding its
/// operands to the ancestors set.
/// We always want an op to be scheduled after all its parents to correctly
/// handle cases with scf operations.
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
Expand All @@ -486,6 +488,14 @@ struct DFSState {
ready = false;
}
}
Operation *parent = op->getParentOp();
Comment thread
ThomasRaoux marked this conversation as resolved.
while (parent) {
if (!seen.count(parent)) {
subGraph.push_back(parent);
ready = false;
}
parent = parent->getParentOp();
}
if (ready)
readyQueue.push_back(op);
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class MoveOpAfterLayoutConversion : public mlir::RewritePattern {
cvtArgOp->getDialect()->getTypeID() !=
mlir::TypeID::get<arith::ArithDialect>())
return mlir::failure();
// Truncation ops (arith.trunci / arith.truncf) are not handled in elementwise lowering.
if (isa<arith::TruncIOp, arith::TruncFOp>(cvtArgOp))
return mlir::failure();
// only considers conversions to dot operand
if (!cvtTy.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
Expand Down
Loading