diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index b72efbdf676e..6c0193182336 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -101,31 +101,28 @@ class GraphLayoutMarker : public GraphDumper { std::string getColor(const Type &type) const; }; -// TODO: Interface -LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, - Attribute &ret); +// Infers the encoding of the result of op given the source encoding. +std::optional inferDstEncoding(Operation *op, Attribute encoding); -bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding); +// Infers the encoding of the source of op given the result encoding. +std::optional inferSrcEncoding(Operation *op, Attribute encoding); -bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding); +bool isExpensiveLoadOrStore(Operation *op); -// skipInit is True when we only consider the operands of the initOp but -// not the initOp itself. -int simulateBackwardRematerialization( - Operation *initOp, SetVector &processed, - SetVector &layout, llvm::MapVector &toConvert, - Attribute targetEncoding); +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); -void rematerializeConversionChain( - const llvm::MapVector &toConvert, - mlir::PatternRewriter &rewriter, SetVector &processed, - IRMapping &mapping); +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + Value root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr); -LogicalResult canMoveOutOfLoop(BlockArgument arg, - SmallVector &cvts); +// Populate pattern to remove dead cycles in ForOp. 
+void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 011ff8d513b2..3beb816d2057 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -473,9 +473,11 @@ struct DFSState { SmallVector topologicalCounts; DenseSet seen; - /// We mark each op as ready if all its operands are seen. If an op is ready, - /// we add it to the queue. Otherwise, we keep adding its operands to the - /// ancestors set. + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, SmallVector &readyQueue) { bool ready = true; @@ -486,6 +488,14 @@ struct DFSState { ready = false; } } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } if (ready) readyQueue.push_back(op); } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 84bf5cebbdbd..43f411bcae7d 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -121,6 +121,9 @@ class MoveOpAfterLayoutConversion : public mlir::RewritePattern { cvtArgOp->getDialect()->getTypeID() != mlir::TypeID::get()) return mlir::failure(); + // not handled in elementwise lowering. 
+ if (isa(cvtArgOp)) + return mlir::failure(); // only considers conversions to dot operand if (!cvtTy.getEncoding().isa()) return mlir::failure(); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index fc6be2afafee..2a38b03d2e14 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -12,11 +12,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" - #include using namespace mlir; @@ -82,542 +82,774 @@ class DecomposeDotOperand : public mlir::RewritePattern { } }; -// It's beneficial to move the conversion -// to after the reduce if necessary since it will be -// done on a rank-reduced tensor hence cheaper -class SimplifyReduceCvt : public mlir::RewritePattern { +// +class ConvertDotConvert : public mlir::RewritePattern { public: - explicit SimplifyReduceCvt(mlir::MLIRContext *context) + ConvertDotConvert(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 2, context) {} + 1, context) {} - mlir::LogicalResult + LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(op)) + auto dstOp = cast(op); + auto dotOp = dstOp.getSrc().getDefiningOp(); + if (!dotOp) return mlir::failure(); - auto convert = llvm::cast(op); - triton::ReduceOp reduce; - for (auto &use : convert.getResult().getUses()) { - auto owner = llvm::dyn_cast(use.getOwner()); - if (!owner) { - continue; - } - - // TODO: This only moves conversions from the first argument 
which is - // fine for argmin/argmax but may not be optimal generally - if (convert.getResult() != owner.getOperands()[0]) { - continue; - } - reduce = owner; - break; - } - if (!reduce) + if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || + std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) return mlir::failure(); - - SmallVector newOperands = reduce.getOperands(); - - newOperands[0] = convert.getOperand(); - auto newEncoding = - newOperands[0].getType().cast().getEncoding(); - - // this may generate unsupported conversions in the LLVM codegen - if (newEncoding.isa()) { - return failure(); - } - - // ReduceOp does not support SharedLayout as its src layout, therefore - // ConvertLayoutOp and ReduceOp should not be swapped when the conversion is - // from SharedLayout to DistributedLayout - if (newEncoding.isa()) { + auto cvtOp = + dotOp.getOperand(2).getDefiningOp(); + if (!cvtOp) + return mlir::failure(); + if (!cvtOp.getSrc().getDefiningOp()) return failure(); - } - - for (unsigned i = 1; i < newOperands.size(); ++i) { - auto oldTy = newOperands[i].getType().cast(); - RankedTensorType newTy = - RankedTensorType::Builder(oldTy).setEncoding(newEncoding); - - newOperands[i] = rewriter.create( - op->getLoc(), newTy, newOperands[i]); - } - - rewriter.setInsertionPoint(reduce); - auto newReduce = rewriter.create( - op->getLoc(), newOperands, reduce.getAxis()); - auto &newCombineOp = newReduce.getCombineOp(); - rewriter.cloneRegionBefore(reduce.getCombineOp(), newCombineOp, - newCombineOp.end()); - - SmallVector newRet = newReduce.getResult(); - auto oldTypes = reduce.getResult().getType(); - for (unsigned i = 0; i < reduce.getNumOperands(); ++i) { - // it's still beneficial to move the conversion - // to after the reduce if necessary since it will be - // done on a rank-reduced tensor hence cheaper - if (newRet[i].getType() != oldTypes[i]) - newRet[i] = rewriter.create( - op->getLoc(), oldTypes[i], newRet[i]); - } - 
rewriter.replaceAllUsesWith(reduce.getResult(), newRet); + auto dstTy = dstOp.getResult().getType().cast(); + auto srcTy = cvtOp.getOperand().getType().cast(); + if (dstTy != srcTy) + return mlir::failure(); - return success(); + auto _0f = rewriter.create( + op->getLoc(), dstTy.getElementType(), + rewriter.getZeroAttr(dstTy.getElementType())); + auto _0 = rewriter.create( + op->getLoc(), dotOp.getResult().getType(), _0f); + auto newDot = rewriter.create( + op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), + dotOp.getOperand(1), _0, dotOp.getAllowTF32()); + auto newCvt = rewriter.create( + op->getLoc(), dstTy, newDot.getResult()); + rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); + return mlir::success(); } }; -// Layout conversions can't deduce their return type automatically. -// IIUC they are therefore not handled by DRR right now -class SimplifyConversion : public mlir::RewritePattern { +// Class to propagate layout globally within a function. +// The current algorithm works by analysis the IR and doing a one shot rewrite +// based on the analysis. The algorithm is as follows: +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. Propagate the layout to every op reachable which is a transitive child of +// an anchor op until we reach a fix point. +// An op can have multiple transitive anchor parents therefore at this stage +// it may have multiple layout associated to it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep. If one of the parents has a different layout than what is picked a +// convert operation will be inserted. After this stage each value should have +// only one layout associated. +// +// 4. Rewrite the IR by walking the function following dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. 
For each op rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { public: - explicit SimplifyConversion(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 4, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(op)) - return mlir::failure(); - auto convert = llvm::cast(op); - return ConvertLayoutOp::canonicalize(convert, rewriter); - } + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(triton::FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + Operation *rewriteYieldOp(scf::YieldOp yieldOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. 
+ void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + std::vector opToDelete; + triton::FuncOp funcOp; }; -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- +} // namespace -// op(cvt(arg_0), arg_1, ..., arg_n) -// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n))) -void pushConversionForward(triton::gpu::ConvertLayoutOp cvt, - SetVector &cvtSlices, - mlir::PatternRewriter &rewriter) { - auto srcEncoding = - cvt.getOperand().getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - IRMapping mapping; - auto op = cvtSlices.front(); - for (Value arg : op->getOperands()) { - if (arg.getDefiningOp() == cvt) - mapping.map(arg, cvt.getOperand()); - else { - auto oldType = arg.getType().dyn_cast(); - // TODO: we may be creating block pointer load/store with mismatching - // pointer type. - if (!oldType) +// Look ahead at the transitive uses and see if there is a convert to mma +// operations. 
+static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + if (auto convertOp = dyn_cast(op)) { + if (convertOp.getResult() + .getType() + .cast() + .getEncoding() == encoding) + return true; + } + auto yield = dyn_cast(op); + if (!yield) continue; - auto newType = RankedTensorType::get( - oldType.getShape(), oldType.getElementType(), srcEncoding); - auto cvtI = rewriter.create(arg.getLoc(), - newType, arg); - if (Operation *argOp = arg.getDefiningOp()) - cvtI->moveAfter(argOp); - mapping.map(arg, cvtI); + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && forwardSlice.count(def) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } } } - rewriter.setInsertionPoint(op); - if (op->getNumResults() == 0) { - Operation *newOp = cloneWithInferType(rewriter, op, mapping); - rewriter.eraseOp(op); - return; - } - auto *newOp = cloneWithInferType(rewriter, op, mapping); - auto newType = newOp->getResult(0).getType().cast(); - auto newCvtType = RankedTensorType::get( - newType.getShape(), newType.getElementType(), dstEncoding); - auto newCvt = rewriter.create( - newOp->getLoc(), newCvtType, newOp->getResult(0)); - rewriter.replaceOp(op, newCvt->getResults()); + return false; } -// -class MoveConvertOutOfIf : public mlir::RewritePattern { -public: - explicit MoveConvertOutOfIf(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::IfOp::getOperationName(), 2, context) {} +// Return true if the op is an op with a layout we don't want to change. 
We will +// propagate the layout starting from anchor ops. +static bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + return false; +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ifOp = cast(*op); - // If “scf.if” defines no values, “scf.yield” will be inserted implicitly. - // However, "scf.else" is not required to be present, so we need to check - // if it exists. - auto thenYield = ifOp.thenYield(); - int numOps = thenYield.getNumOperands(); - SmallVector newThenYieldOps = thenYield.getOperands(); - SetVector thenCvts; - SmallVector newRetTypes; - - bool hasElse = !ifOp.getElseRegion().empty(); - - scf::YieldOp elseYield; - SmallVector newElseYieldOps; - SetVector elseCvts; - if (hasElse) { - elseYield = ifOp.elseYield(); - newElseYieldOps = elseYield.getOperands(); +void LayoutPropagation::initAnchorLayout() { + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + if (auto tensorType = result.getType().dyn_cast()) { + // Workaround, don't propagate MMA layout unless there is a convert + // back to mma further down to avoid generating reduction with MMA + // layout that may have lower performance. + // This can be improved with more aggressive backward propagation. 
+ if (tensorType.getEncoding().isa() && + !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) + continue; + layouts.insert({result, tensorType.getEncoding()}); + } + } } + }); +} - IRMapping mapping; - for (size_t i = 0; i < numOps; i++) { - auto thenCvt = - thenYield.getOperand(i).getDefiningOp(); - if (hasElse) { - auto elseYield = ifOp.elseYield(); - auto elseCvt = elseYield.getOperand(i) - .getDefiningOp(); - if (thenCvt && elseCvt && - std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 && - std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1 && - thenCvt.getOperand().getType() == elseCvt.getOperand().getType()) { - // If thenCvt and elseCvt's type are the same, it means a single - // conversion is enough to replace both of them. We can move the - // conversion out of scf.if and replace both thenCvt and elseCvt with - // the new conversion. - mapping.map(thenCvt.getResult(), thenCvt.getOperand()); - thenCvts.insert((Operation *)thenCvt); - newRetTypes.push_back(thenCvt.getOperand().getType()); - mapping.map(elseCvt.getResult(), elseCvt.getOperand()); - elseCvts.insert((Operation *)elseCvt); - } else - // Cannot move out of scf.if because thenCvt != elseCvt - // Moving it out of scf.if will introduce a new conversion - newRetTypes.push_back(thenYield.getOperand(i).getType()); - } else { - if (thenCvt && - std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1) { - // If there's only a single use of the conversion then we can move it - mapping.map(thenCvt.getResult(), thenCvt.getOperand()); - thenCvts.insert((Operation *)thenCvt); - newRetTypes.push_back(thenCvt.getOperand().getType()); - } else - // Cannot move out of scf.if because either there's another use of - // the conversion or there's no conversion at all - newRetTypes.push_back(thenYield.getOperand(i).getType()); - } +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { 
+ if (!value.getType().isa()) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + auto dstEncoding = inferDstEncoding(op, encoding); + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(*dstEncoding); } - if (mapping.getValueMap().empty()) - return mlir::failure(); + if (hasChanged) + changed.push_back(value); + } +} - auto newIfOp = rewriter.create(ifOp.getLoc(), newRetTypes, - ifOp.getCondition(), hasElse); - auto rematerialize = [&](Block *block, SetVector &cvts) { - for (Operation &op : block->getOperations()) { - if (cvts.contains(&op)) { - if (mapping.contains(op.getOperand(0))) - mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0))); - continue; - } - cloneWithInferType(rewriter, &op, mapping); - } - }; - rewriter.setInsertionPointToEnd(newIfOp.thenBlock()); - rematerialize(ifOp.thenBlock(), thenCvts); - if (hasElse) { - rewriter.setInsertionPointToEnd(newIfOp.elseBlock()); - rematerialize(ifOp.elseBlock(), elseCvts); +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getRegionIterArgForOpOperand(use); + Value result = forOp.getResultForOpOperand(use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate = { + parent->getResult(use.getOperandNumber())}; + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (isa(parent)) + setEncoding({valuesToPropagate}, info, changed, user); + // TODO: handle while. 
+ continue; + } + // Workaround: don't propagate through truncI + if (isa(user)) + continue; + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} - rewriter.setInsertionPointAfter(newIfOp); - SmallVector newRetValues = newIfOp.getResults(); - for (size_t i = 0; i < numOps; i++) { - if (newIfOp.getResult(i).getType() != ifOp.getResult(i).getType()) { - newRetValues[i] = rewriter.create( - newIfOp.getLoc(), ifOp.getResult(i).getType(), - newIfOp.getResult(i)); +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. 
+ Attribute encoding = *info.encodings.begin(); + for (Attribute e : info.encodings) { + if (e.isa()) { + encoding = e; + break; } } - - rewriter.replaceOp(op, newRetValues); - return mlir::success(); + info.encodings.clear(); + info.encodings.insert(encoding); } -}; +} -// -class RematerializeForward : public mlir::RewritePattern { -public: - explicit RematerializeForward(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context) {} +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *cvtOp, - mlir::PatternRewriter &rewriter) const override { - auto cvt = dyn_cast(*cvtOp); - auto srcEncoding = - cvt.getOperand().getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - if (srcEncoding.isa() || - dstEncoding.isa()) - return failure(); - // heuristics for flash attention - if (srcEncoding.isa()) - return failure(); - // For cases like: - // %0 = convert_layout %arg0 - // We should try to move %0 out of scf.for first, if it couldn't be moved - // out additional conversions will be added to the loop body. 
- if (!cvt.getOperand().getDefiningOp() && - isa(cvt->getParentOp())) - return failure(); +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } - SetVector cvtSlices; - auto filter = [&](Operation *op) { - return op->getBlock() == cvt->getBlock() && - !isa(op) && - !(isa(op) && - !op->getResult(0).getType().isa()); - }; - mlir::getForwardSlice(cvt.getResult(), &cvtSlices, {filter}); - if (cvtSlices.empty()) - return failure(); +static bool allowChangingSrcEncoding(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + if (isa(op) && + !op->getResultTypes()[0].isa() && + op->getNumOperands() == 1) + return true; + return false; +} - for (Operation *op : cvtSlices) { - // don't rematerialize anything expensive - if (isExpensiveToRemat(op, srcEncoding)) - return failure(); - // don't rematerialize non-element-wise - if (!op->hasTrait() && - !op->hasTrait() && - !isa(op)) - return failure(); - // don't rematerialize if it adds an extra conversion that can't - // be removed - for (Value arg : op->getOperands()) { - Operation *argOp = arg.getDefiningOp(); - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - int numAddedConvs = simulateBackwardRematerialization( - argOp, processed, layout, toConvert, srcEncoding); - if (argOp && !isa(argOp) && - cvtSlices.count(argOp) == 0 && numAddedConvs > 0) - return failure(); +void LayoutPropagation::rewriteRegion(Region &region) { + SmallVector queue = {&region}; + while (!queue.empty()) { + Region *currentRegion = queue.back(); + queue.pop_back(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. 
+ if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = result.getType().cast().getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else { + bool canChangeSrcEncoding = allowChangingSrcEncoding(&op); + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + operand.get().getType().cast().getEncoding(); + if (canChangeSrcEncoding) + encoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); } } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} - // Call SimplifyReduceCvt instead of the general push conversion forward - if (isa(cvtSlices.front())) - return failure(); +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, newV.getType().cast().getEncoding()}] = + newV; +} - pushConversionForward(cvt, cvtSlices, rewriter); - return success(); +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = value.getType().dyn_cast()) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == 
tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (rewrittenValue.getType().cast().getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create( + value.getLoc(), tmpType, rewrittenValue); + // TODO: we could cache the conversion. + return converted; } -}; + return value; +} -// Layout conversions are expensive. They require going through -// shared memory, which is orders of magnitude slower than -// other non-i/o operations in the dialect. -// It therefore makes sense to remove them whenever possible, -// even if it means rematerializing all values whose definitions -// are reachable from it without passing through any memory operation. -class RematerializeBackward : public mlir::RewritePattern { -public: - explicit RematerializeBackward(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 3, context) {} +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + for (OpOperand &operand : op->getOpOperands()) + newOp->setOperand( + operand.getOperandNumber(), + getValueAs(operand.get(), *inferSrcEncoding(op, encoding))); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = op->getResult(i).getType().dyn_cast(); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *cvt, - mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(cvt)) - return mlir::failure(); - // 
we don't touch block arguments - Operation *op = cvt->getOperand(0).getDefiningOp(); - if (!op) - return mlir::failure(); - // we don't want to rematerialize any conversion to/from shared - if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) || - triton::gpu::isSharedEncoding(cvt->getOperand(0))) - return mlir::failure(); - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention - auto targetType = cvt->getResultTypes()[0].cast(); - if (targetType.getEncoding().isa()) - return mlir::failure(); - // DFS - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - if (simulateBackwardRematerialization(cvt, processed, layout, toConvert, - targetType.getEncoding()) > 0) - return mlir::failure(); +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } - IRMapping mapping; - rematerializeConversionChain(toConvert, rewriter, processed, mapping); - rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + 
oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} - return mlir::success(); +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = ifOp->getResult(i).getType().cast(); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); } -}; + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- +Operation *LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + OpBuilder rewriter(yieldOp); + Operation *newYield = rewriter.clone(*yieldOp.getOperation()); + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + auto tensorType = yieldType.dyn_cast(); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + newYield->setOperand(operand.getOperandNumber(), newOperand); + } + 
opToDelete.push_back(yieldOp.getOperation()); + return newYield; +} -class MoveConvertOutOfLoop : public mlir::RewritePattern { -public: - explicit MoveConvertOutOfLoop(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {} - - SmallVector - rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp, - size_t i, RankedTensorType newType, - triton::gpu::ConvertLayoutOp origConversion) const { - // Rewrite init argument - auto origType = forOp.getInitArgs()[i].getType().cast(); - SmallVector newInitArgs = forOp.getInitArgs(); - newInitArgs[i] = rewriter.create( - newInitArgs[i].getLoc(), newType, newInitArgs[i]); - // Clone for loop - auto newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newInitArgs); - newForOp->moveBefore(forOp); - rewriter.setInsertionPointToStart(newForOp.getBody()); - IRMapping mapping; - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]); - - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - for (Operation &op : forOp.getBody()->without_terminator()) { - if (dyn_cast(op) == origConversion) - continue; +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.push_back(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = + convertOp.getOperand().getType().cast().getEncoding(); + auto it = layouts.find(convertOp.getOperand()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getOperand(), srcEncoding); + auto tensorType = 
op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), + newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create( + op->getLoc(), newType, newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa( + op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) + map(oldResult, newResult); + return newOp; + } + assert(0 && "unexpected op in rewrite"); + return nullptr; +} - bool convert = llvm::any_of(op.getOperands(), [&](auto operand) { - return operand == origConversion.getOperand(); - }); - auto convertLayout = [&](Value operand, Value value, Attribute encoding) { - auto tensorType = value.getType().cast(); - auto cvtType = RankedTensorType::get( - tensorType.getShape(), tensorType.getElementType(), encoding); - auto cvt = rewriter.create( - op.getLoc(), cvtType, value); - mapping.map(operand, cvt); - }; - DenseMap cvtValues; - if (convert) { - for (auto operand : op.getOperands()) { - if (operand == origConversion.getOperand() || - !isa(operand.getType())) - continue; - auto value = mapping.lookupOrDefault(operand); - // Convert to the new type - convertLayout(operand, value, newType.getEncoding()); - // Other ops don't use the converted value and we need to restore - cvtValues[operand] = value; +static bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (isa(op)) + return false; + 
+ return true; +} + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, + scf::ForOp loop, + ValueRange newIterOperands) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + rewriter.setInsertionPoint(loop); + auto operands = llvm::to_vector<4>(loop.getIterOperands()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop.getBody()->erase(); + + newLoop.getLoopBody().getBlocks().splice( + newLoop.getLoopBody().getBlocks().begin(), + loop.getLoopBody().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + return newLoop; +} + +static void rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping) { + SetVector opsToRewrite; + for (Value v : slice) { + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + } else { + opsToRewrite.insert(v.cast().getOwner()->getParentOp()); + // We also need to rewrite the yield op. + opsToRewrite.insert(v.cast().getOwner()->getTerminator()); + } + } + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + SmallVector deadLoops; + OpBuilder builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index.
+ SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); + argMapping.push_back( + std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal), + forOp.getNumIterOperands() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); } } - auto *newOp = cloneWithInferType(rewriter, &op, mapping); - if (convert) { - for (auto result : op.getResults()) { - if (!isa(result.getType())) - continue; - auto value = mapping.lookupOrDefault(result); - auto tensorType = result.getType().cast(); - // Convert to the original type - convertLayout(result, value, tensorType.getEncoding()); - } - // Restore original values - for (auto [operand, value] : cvtValues) - mapping.map(operand, value); + // Create a new for loop with the new operands. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + deadLoops.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + for (Value operand : yieldOp.getOperands()) { + if (slice.count(operand) == 0) + continue; + yieldOperands.push_back(mapping.lookup(operand)); } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create( + op->getLoc(), newType, 
newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + old.getType().cast().getShape(), + old.getType().cast().getElementType(), it->second); + newV.setType(newType); } - // create yield, inserting conversions if necessary - auto yieldOp = forOp.getBody()->getTerminator(); - SmallVector newYieldArgs; - // We use the new type for the result of the conversion - for (Value arg : yieldOp->getOperands()) - newYieldArgs.push_back(mapping.lookup(arg)); - if (newYieldArgs[i].getType() != newType) - newYieldArgs[i] = rewriter.create( - yieldOp->getLoc(), newType, newYieldArgs[i]); - rewriter.create(forOp.getLoc(), newYieldArgs); - - // replace - SmallVector newResults = newForOp->getResults(); - newResults[i] = rewriter.create( - newForOp.getLoc(), origType, newForOp->getResult(i)); - newResults[i].getDefiningOp()->moveAfter(newForOp); - - return newResults; } + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getOperand())); + convertOp.erase(); + for (Operation *op : deadLoops) + op->erase(); +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto forOp = cast(op); - auto iterArgs = forOp.getRegionIterArgs(); - for (const auto &iterArg : llvm::enumerate(iterArgs)) { - // skip non-tensor types - if (!iterArg.value().getType().isa()) - continue; - SmallVector cvts; - if (canMoveOutOfLoop(iterArg.value(), cvts).failed()) - continue; - // check - for (auto *op : cvts) { - auto cvt = dyn_cast(op); - auto targetType = op->getResultTypes()[0].cast(); - auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(), - targetType, cvt); - rewriter.replaceOp(forOp, newFor); - return success(); - } +static void rewriteSlice(SetVector &slice, 
+ DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +static void backwardRematerialization(ConvertLayoutOp convertOp) { + // we don't want to rematerialize any conversion to/from shared + if (triton::gpu::isSharedEncoding(convertOp.getResult()) || + triton::gpu::isSharedEncoding(convertOp.getOperand())) + return; + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto targetType = convertOp->getResultTypes()[0].cast(); + if (targetType.getEncoding().isa()) + return; + + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getConvertBackwardSlice( + convertOp.getOperand(), slice, targetType.getEncoding(), layout); + if (result.failed() || slice.empty()) + return; + + // 2. Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return; } - return failure(); } -}; + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} -// -class ConvertDotConvert : public mlir::RewritePattern { -public: - ConvertDotConvert(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context) {} +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. 
+static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { + // we don't want to rematerialize any conversion to/from shared + if (triton::gpu::isSharedEncoding(convertOp.getResult()) || + triton::gpu::isSharedEncoding(convertOp.getOperand())) + return; + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto targetType = convertOp->getResultTypes()[0].cast(); + if (targetType.getEncoding().isa()) + return; - LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dstOp = cast(op); - auto dotOp = dstOp.getSrc().getDefiningOp(); - if (!dotOp) - return mlir::failure(); - if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || - std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) - return mlir::failure(); - auto cvtOp = - dotOp.getOperand(2).getDefiningOp(); - if (!cvtOp) - return mlir::failure(); - if (!cvtOp.getSrc().getDefiningOp()) - return failure(); - auto dstTy = dstOp.getResult().getType().cast(); - auto srcTy = cvtOp.getOperand().getType().cast(); - if (dstTy != srcTy) - return mlir::failure(); + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + auto isExtOp = [](Operation *op) { + return isa(op); + }; + // Get a backward slice but don't go past ext ops + LogicalResult result = getConvertBackwardSlice( + convertOp.getOperand(), slice, targetType.getEncoding(), layout, isExtOp); + if (result.failed() || slice.empty()) + return; + Operation *extOp = nullptr; + // 2. Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return; + if (isExtOp(op)) { + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. 
+ if (extOp != nullptr) + return; + extOp = op; + } + } + } + if (extOp == nullptr) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOp); + auto tensorType = extOp->getOperand(0).getType().cast(); + auto newType = + RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), + layout[extOp->getResult(0)]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOp->getOperand(0)); + IRMapping mapping; + mapping.map(extOp->getOperand(0), newConvertOp.getResult()); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} - auto _0f = rewriter.create( - op->getLoc(), dstTy.getElementType(), - rewriter.getZeroAttr(dstTy.getElementType())); - auto _0 = rewriter.create( - op->getLoc(), dotOp.getResult().getType(), _0f); - auto newDot = rewriter.create( - op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), - dotOp.getOperand(1), _0, dotOp.getAllowTF32()); - auto newCvt = rewriter.create( - op->getLoc(), dstTy, newDot.getResult()); - rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); - return mlir::success(); +static void backwardRematerialization(ModuleOp module) { + SmallVector convertOps; + module.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); } -}; +} -} // namespace +static void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExt(convertOp); + } +} #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -632,18 +864,45 @@ class TritonGPURemoveLayoutConversionsPass MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - mlir::RewritePatternSet patterns(context); + // 1. Propagate layout forward starting from "anchor" ops. 
+ m.walk([](triton::FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + mlir::RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); + // 2. For convert ops left try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + // 3. For converts left try to hoist them above cast generating larger size + // types in order to reduce the cost of the convert op. + hoistConvert(m); + + mlir::RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + decomposePatterns.add(context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } - if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + // 4. Apply clean up patterns to remove dead convert and dead code + // generated by the previous transformations.
+ mlir::RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2)) + .failed()) { signalPassFailure(); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index df605853b1f6..62791835067c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -240,30 +240,57 @@ std::string GraphLayoutMarker::getColor(const Type &type) const { } // -------------------------------------------------------------------------- // -// TODO: Interface -LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, - Attribute &ret) { - ret = targetEncoding; - if (auto expand_dims = dyn_cast(op)) { - ret = triton::gpu::SliceEncodingAttr::get( - op->getContext(), expand_dims.getAxis(), targetEncoding); - } - if (auto reduce = dyn_cast(op)) { - auto sliceEncoding = - targetEncoding.dyn_cast(); - if (!sliceEncoding) - return failure(); - if (sliceEncoding.getDim() != reduce.getAxis()) - return failure(); - ret = sliceEncoding.getParent(); - } - if (isa(op)) { - return failure(); - } - return success(); +static std::optional inferDstEncoding(triton::ReduceOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferDstEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + auto sliceEncoding = encoding.dyn_cast(); + if (!sliceEncoding) + return std::nullopt; + assert(op.getAxis() == sliceEncoding.getDim()); + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ReduceOp op, + Attribute encoding) { + auto sliceEncoding = encoding.dyn_cast(); + if (!sliceEncoding) + return std::nullopt; + 
assert(op.getAxis() == sliceEncoding.getDim()); + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); } -bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { +std::optional inferSrcEncoding(Operation *op, Attribute encoding) { + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (isa(op)) + return std::nullopt; + return encoding; +} + +std::optional inferDstEncoding(Operation *op, Attribute encoding) { + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (isa(op)) + return std::nullopt; + return encoding; +} + +bool isExpensiveLoadOrStore(Operation *op) { // Case 1: Pointer of tensor is always expensive auto operandType = op->getOperand(0).getType(); if (triton::isTensorPointerType(operandType)) @@ -287,7 +314,7 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { if (!op) return true; if (isa(op)) - return isExpensiveLoadOrStore(op, targetEncoding); + return isExpensiveLoadOrStore(op); if (isa(op)) return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) return !triton::gpu::isExpensiveCat(cast(op), targetEncoding); - return isa(op); -} - -int simulateBackwardRematerialization( - Operation *initOp, SetVector &processed, - SetVector &layout, llvm::MapVector &toConvert, - Attribute targetEncoding) { - // DFS - std::vector> queue; - queue.emplace_back(initOp, targetEncoding); - // We want to see the effect of converting `initOp` to a new layout - // so we initialize `numCvts = 1`. 
- int numCvts = 1; - while (!queue.empty()) { - Operation *currOp; - Attribute currLayout; - std::tie(currOp, currLayout) = queue.back(); - queue.pop_back(); - // If the current operation is expensive to rematerialize, - // we stop everything - if (isExpensiveToRemat(currOp, currLayout)) - break; - // A conversion will be removed here (i.e. transferred to operands) - numCvts -= 1; - // Done processing - processed.insert(currOp); - layout.insert(currLayout); - // Add all operands to the queue - for (Value argI : currOp->getOperands()) { - Attribute newEncoding; - // Cannot invert the current encoding for this operand - // we stop everything - if (failed(invertEncoding(currLayout, currOp, newEncoding))) - return INT_MAX; - if (toConvert.count(argI) && toConvert[argI] != newEncoding) - return INT_MAX; - if (auto ptrTy = argI.getType().dyn_cast()) { - if (ptrTy.getPointeeType().isa()) { - return INT_MAX; - } - } - - Operation *opArgI = argI.getDefiningOp(); - toConvert.insert({argI, newEncoding}); - // 1. Only convert RankedTensorType - // 2. Skip if there's no defining op - // 3. Skip if the defining op has already been processed - // 4. 
Skip or the defining op is in a different block - if (!argI.getType().isa() || !opArgI || - processed.contains(opArgI) || - opArgI->getBlock() != currOp->getBlock()) - continue; - // If the conversion can be folded into opArgI then - // we don't count this conversion as expensive - if (canFoldConversion(opArgI, newEncoding)) - continue; - - // We add one expensive conversion for the current operand - numCvts += 1; - queue.emplace_back(opArgI, newEncoding); + if (auto convert = dyn_cast(op)) { + if (targetEncoding.isa()) { + auto srcEncoding = + convert.getOperand().getType().cast().getEncoding(); + if (targetEncoding != srcEncoding) + return false; } + return true; } - // return net number of conversions - return numCvts; + return isa(op); } // @@ -409,213 +382,54 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, return newOp; } -namespace { - -struct OpUseInfo { - Value value; - Operation *op; - unsigned index; -}; - -void getForwardSliceOpUseInfo(Operation *op, - SetVector *forwardSliceOps, - SmallVector *forwardOpUseInfo) { - if (!op) - return; - - for (Region &region : op->getRegions()) - for (Block &block : region) - for (Operation &blockOp : block) - if (forwardSliceOps->count(&blockOp) == 0) - getForwardSliceOpUseInfo(&blockOp, forwardSliceOps, forwardOpUseInfo); - for (Value result : op->getResults()) { - for (OpOperand &operand : result.getUses()) { - auto *blockOp = operand.getOwner(); - forwardOpUseInfo->push_back( - {operand.get(), blockOp, operand.getOperandNumber()}); - if (forwardSliceOps->count(blockOp) == 0) - getForwardSliceOpUseInfo(blockOp, forwardSliceOps, forwardOpUseInfo); - } - } - - forwardSliceOps->insert(op); -} -} // namespace - -LogicalResult simulateForwardRematerializationInLoop(Operation *startOp, - BlockArgument arg, - Attribute targetEncoding) { - // heuristics for flash attention - if (targetEncoding.isa()) - return failure(); - SetVector cvtSliceOps; - SmallVector cvtSliceOpUseInfo; -
getForwardSliceOpUseInfo(startOp, &cvtSliceOps, &cvtSliceOpUseInfo); - - // Check if any additional conversion is needed along the way - for (Operation *op : cvtSliceOps) { - if (isa(op)) +LogicalResult +getConvertBackwardSlice(Value root, SetVector &slice, + Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation) { + SmallVector> queue = {{root, rootEncoding}}; + while (!queue.empty()) { + auto [currentValue, encoding] = queue.back(); + queue.pop_back(); + if (!currentValue.getType().isa()) continue; - // The first op doesn't push forward any conversion - if (op != startOp) { - if (isa(op) && - !op->getResult(0).getType().isa()) - return failure(); - // don't rematerialize anything expensive - if (isExpensiveToRemat(op, targetEncoding)) - return failure(); - // don't rematerialize non-element-wise - if (!op->hasTrait() && - !op->hasTrait() && - !isa(op)) - return failure(); - } - // don't rematerialize if it adds an extra conversion that can't - // be removed - for (Value value : op->getOperands()) { - Operation *argOp = arg.getDefiningOp(); - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - int numAddedConvs = simulateBackwardRematerialization( - argOp, processed, layout, toConvert, targetEncoding); - if (argOp && !isa(argOp) && - cvtSliceOps.count(argOp) == 0 && numAddedConvs > 0) + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + slice.insert(currentValue); + layout[currentValue] = encoding; + if (auto *definingOp = currentValue.getDefiningOp()) { + if (canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) return failure(); - } - } - - // We apply conservative analysis. Only when the final operand's index - // matches the argument's index or their encoding match, we can rematerialize. 
- for (auto &opUseInfo : cvtSliceOpUseInfo) { - Operation *op = opUseInfo.op; - if (isa(op)) { - auto yieldIdx = opUseInfo.index; - // 0 is the induction variable - auto argIdx = arg.getArgNumber() - 1; - if (yieldIdx != argIdx) { - auto argType = arg.getType().cast(); - auto yieldType = - op->getOperand(yieldIdx).getType().dyn_cast(); - if (!yieldType || argType.getEncoding() != yieldType.getEncoding()) + for (Value operand : definingOp->getOperands()) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) return failure(); + if (slice.count(operand) == 0) + queue.push_back({operand, *srcEncoding}); } + continue; } - } - return success(); -} - -void rematerializeConversionChain( - const llvm::MapVector &toConvert, - mlir::PatternRewriter &rewriter, SetVector &processed, - IRMapping &mapping) { - SmallVector sortedValues; - SetVector tmp; - for (auto &item : toConvert) { - Value v = item.first; - if (v.getDefiningOp()) - tmp.insert(v.getDefiningOp()); - else - sortedValues.push_back(v); - } - tmp = mlir::multiRootTopologicalSort(tmp); - for (Operation *op : tmp) - sortedValues.push_back(op->getResult(0)); - - for (Value currOperand : sortedValues) { - Value origOperand = currOperand; - // unpack information - Attribute targetLayout = toConvert.lookup(currOperand); - // rematerialize the operand if necessary - Operation *currOperation = currOperand.getDefiningOp(); - if (processed.contains(currOperation)) { - Operation *newOperation = - cloneWithInferType(rewriter, currOperation, mapping); - newOperation->moveAfter(currOperation); - currOperation = newOperation; - currOperand = currOperation->getResult(0); - } - // compute target type for the layout cast - auto currType = currOperand.getType().cast(); - auto newType = RankedTensorType::get( - currType.getShape(), currType.getElementType(), targetLayout); - auto newOperand = rewriter.create( - currOperand.getLoc(), newType, currOperand); - if (currOperation) - 
newOperand->moveAfter(currOperation); - else { - Block *block = currOperand.cast().getOwner(); - newOperand->moveBefore(block, block->begin()); + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); + Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + queue.push_back({initOperand.get(), encoding}); + queue.push_back({yieldOperand, encoding}); + continue; } - mapping.map(origOperand, newOperand); - } -} - -LogicalResult canMoveOutOfLoop(BlockArgument arg, - SmallVector &cvts) { - auto parentOp = arg.getOwner()->getParentOp(); - // Don't move if arg is defined in a while loop - if (isa(parentOp)) + // TODO: add support for WhileOp and other region types. return failure(); - // Skip if arg is not defined in scf.for - if (!isa(parentOp)) - return success(); - auto forOp = cast(parentOp); - // We only move `iterArg` out of the loop if - // 1. There is no conversion - // 2. There is only a single conversion - // 3. 
Moving this conversion out of the loop will not generate any extra - // non-removable conversion - SetVector cvtTypes; - SetVector others; - auto oldType = arg.getType().cast(); - for (auto user : arg.getUsers()) { - if (isa(user)) { - // Don't move if the conversion target is a dot operand or shared memory - auto newType = user->getResults()[0].getType().cast(); - if (oldType.getEncoding().isa() && - newType.getEncoding().isa()) { - continue; - } - if (newType.getEncoding().isa()) { - if (newType.getEncoding() - .cast() - .getVec() == 1) - continue; - } - cvts.emplace_back(user); - cvtTypes.insert(newType); - } else - others.insert(user); - } - // First condition - if (cvts.empty()) - return success(); - if (cvtTypes.size() == 1) { - // Third condition - part 1: - // If the other or the cvt is in the different block, we cannot push the - // conversion forward or backward - for (auto *cvt : cvts) { - if (cvt->getBlock() != forOp.getBody()) - return failure(); - } - auto targetEncoding = cvtTypes.front().getEncoding(); - for (auto *other : others) { - // Third condition - part 2: - // If the other non-cvt op is in the different block, we cannot push the - // conversion forward or backward - if (other->getBlock() != forOp.getBody()) - return failure(); - // Third condition - part 3: - // Check if we can directly use arg without conversion - if (simulateForwardRematerializationInLoop(other, arg, targetEncoding) - .failed()) - return failure(); - } - return success(); } - return failure(); + return success(); } // TODO(thomas): this is duplicated with what is in GPUToLLVM @@ -700,4 +514,117 @@ void setRoleId(Operation *op, int roleId) { op->setAttr("agent.mutex_role", attr); } +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. 
+struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all their operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point.
+ while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = value.cast(); + OpOperand &forOperand = nestedFor.getOpOperandForResult(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = value.cast(); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If a block argument is live then the associated yield operand and + // forOp operand are live. + auto arg = value.cast(); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getIterOperands()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.updateRootInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination.
+ for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + } // namespace mlir diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 3693930cc1ab..ba240ae2e11e 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -77,6 +77,20 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { tt.return } +// Hoist the convert on top of ext to make it cheaper. +// CHECK-LABEL: hoist_above_ext +tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout1> + %2 = triton_gpu.convert_layout %0 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> + %3 = arith.addf %1, %2 : tensor<1024xf32, #layout1> + tt.return %3 : tensor<1024xf32, #layout1> +} + // CHECK-LABEL: if tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout @@ -229,8 +243,10 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: } - // CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]> // CHECK-NOT: triton_gpu.convert_layout + // CHECK: {{.*}} = triton_gpu.convert_layout 
[[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]> + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: tt.return %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> %c1 = arith.constant 1 : index @@ -276,6 +292,19 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, } // CHECK-LABEL: loop_if +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.for +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.if +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.yield +// CHECK: else +// CHECK: scf.yield +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.yield +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.store module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { %cst = arith.constant dense : tensor<64x64xi1, #blocked1> @@ -1125,14 +1154,14 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // ----- -// Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop. 
// CHECK-LABEL: reduce_cvt2 // Match the reduction // CHECK: tt.reduce // CHECK-SAME: axis = 1 // CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -// CHECK-NEXT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout // CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> @@ -1347,6 +1376,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // Check if MoveConvertOutOfLoop hangs because of adding additional conversions // CHECK-LABEL: loop_print // CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -1502,3 +1532,211 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + + +// ----- + +// Check that we don't have extra convert for flash attention IR. 
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [1, 0, 2]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#blocked6 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked7 = #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [1, 0, 2]}> +#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : 
i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @attention_fw(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c0_i64 = arith.constant 0 : i64 + %c64_i64 = arith.constant 64 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked2> + %cst_3 = arith.constant 1.44269502 : f32 + %c128_i32 = 
arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = arith.muli %1, %arg10 : i32 + %4 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %5 = arith.muli %0, %c128_i32 : i32 + %6 = arith.extsi %arg8 : i32 to i64 + %7 = arith.extsi %5 : i32 to i64 + %8 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %9 = arith.addi %arg20, %arg21 : i32 + %10 = arith.extsi %arg11 : i32 to i64 + %11 = tt.addptr %arg2, %3 : !tt.ptr, i32 + %12 = arith.extsi %arg14 : i32 to i64 + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %14 = tt.splat %5 : (i32) -> tensor<128xi32, #blocked1> + %15 = arith.addi %14, %13 : tensor<128xi32, #blocked1> + %16 = arith.mulf %arg3, %cst_3 : f32 + %17 = tt.splat %4 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked3> + %18 = tt.splat %7 : (i64) -> tensor<128xi64, #blocked3> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> + %20 = arith.extsi %19 : tensor<128xi32, #blocked3> to tensor<128xi64, #blocked3> + %21 = arith.addi %18, %20 : tensor<128xi64, #blocked3> + %22 = triton_gpu.convert_layout %21 : (tensor<128xi64, #blocked3>) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<128x1xi64, #blocked4> + %24 = tt.splat %6 : (i64) -> tensor<128x1xi64, #blocked4> + %25 = arith.muli %23, %24 : tensor<128x1xi64, #blocked4> + %26 = tt.broadcast %25 : (tensor<128x1xi64, #blocked4>) -> tensor<128x64xi64, #blocked4> + %27 = triton_gpu.convert_layout %26 : (tensor<128x64xi64, #blocked4>) -> tensor<128x64xi64, #blocked3> + %28 = tt.addptr %17, %27 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %29 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %30 = arith.extsi %29 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %31 = 
triton_gpu.convert_layout %30 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %32 = tt.expand_dims %31 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %33 = tt.broadcast %32 : (tensor<1x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked5> + %34 = triton_gpu.convert_layout %33 : (tensor<128x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked3> + %35 = tt.addptr %28, %34 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %36 = tt.load %35 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked3> + %37 = triton_gpu.convert_layout %36 : (tensor<128x64xf16, #blocked3>) -> tensor<128x64xf16, #blocked2> + %38 = tt.splat %16 : (f32) -> tensor<128x64xf32, #blocked2> + %39 = arith.extf %37 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %40 = arith.mulf %39, %38 : tensor<128x64xf32, #blocked2> + %41 = arith.truncf %40 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.for +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK: scf.yield + %42:5 = scf.for %arg22 = %c0_i32 to %9 step %c64_i32 iter_args(%arg23 = %cst_2, %arg24 = %cst_1, %arg25 = %cst_0, %arg26 = %c0_i64, %arg27 = %c0_i64) -> (tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64) : i32 { + %78 = tt.splat %8 : (!tt.ptr) -> tensor<64x64x!tt.ptr, #blocked6> + %79 = tt.make_range {end = 64 : 
i32, start = 0 : i32} : tensor<64xi32, #blocked6> + %80 = arith.extsi %79 : tensor<64xi32, #blocked6> to tensor<64xi64, #blocked6> + %81 = triton_gpu.convert_layout %80 : (tensor<64xi64, #blocked6>) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked7}>>) -> tensor<64x1xi64, #blocked7> + %83 = tt.broadcast %82 : (tensor<64x1xi64, #blocked7>) -> tensor<64x64xi64, #blocked7> + %84 = triton_gpu.convert_layout %83 : (tensor<64x64xi64, #blocked7>) -> tensor<64x64xi64, #blocked6> + %85 = tt.addptr %78, %84 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %86 = tt.splat %arg26 : (i64) -> tensor<64xi64, #blocked6> + %87 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6> + %88 = arith.extsi %87 : tensor<64xi32, #blocked6> to tensor<64xi64, #blocked6> + %89 = arith.addi %86, %88 : tensor<64xi64, #blocked6> + %90 = triton_gpu.convert_layout %89 : (tensor<64xi64, #blocked6>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked8}>> + %91 = tt.expand_dims %90 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked8}>>) -> tensor<1x64xi64, #blocked8> + %92 = tt.splat %10 : (i64) -> tensor<1x64xi64, #blocked8> + %93 = arith.muli %91, %92 : tensor<1x64xi64, #blocked8> + %94 = tt.broadcast %93 : (tensor<1x64xi64, #blocked8>) -> tensor<64x64xi64, #blocked8> + %95 = triton_gpu.convert_layout %94 : (tensor<64x64xi64, #blocked8>) -> tensor<64x64xi64, #blocked6> + %96 = tt.addptr %85, %95 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %97 = tt.load %96 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked6> + %98 = tt.splat %11 : (!tt.ptr) -> tensor<64x64x!tt.ptr, #blocked3> + %99 = tt.splat %arg27 : (i64) -> tensor<64xi64, #blocked3> + %100 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, 
#blocked3> + %101 = arith.extsi %100 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %102 = arith.addi %99, %101 : tensor<64xi64, #blocked3> + %103 = triton_gpu.convert_layout %102 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi64, #blocked4> + %105 = tt.splat %12 : (i64) -> tensor<64x1xi64, #blocked4> + %106 = arith.muli %104, %105 : tensor<64x1xi64, #blocked4> + %107 = tt.broadcast %106 : (tensor<64x1xi64, #blocked4>) -> tensor<64x64xi64, #blocked4> + %108 = triton_gpu.convert_layout %107 : (tensor<64x64xi64, #blocked4>) -> tensor<64x64xi64, #blocked3> + %109 = tt.addptr %98, %108 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %110 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %111 = arith.extsi %110 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %112 = triton_gpu.convert_layout %111 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %113 = tt.expand_dims %112 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %114 = tt.broadcast %113 : (tensor<1x64xi64, #blocked5>) -> tensor<64x64xi64, #blocked5> + %115 = triton_gpu.convert_layout %114 : (tensor<64x64xi64, #blocked5>) -> tensor<64x64xi64, #blocked3> + %116 = tt.addptr %109, %115 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %117 = tt.load %116 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3> + %118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent 
= #blocked}>> + %120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2> + %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.maxf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> + %125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1> + %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> + %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %129 = tt.expand_dims %128 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %130 = triton_gpu.convert_layout %129 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %131 = tt.broadcast %130 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %132 = arith.subf %122, %131 : tensor<128x64xf32, #blocked2> + %133 = tt.extern_elementwise %132 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", 
symbol = "__nv_exp2f"} : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %134 = arith.mulf %arg24, %cst_1 : tensor<128xf32, #blocked1> + %135 = arith.addf %134, %127 : tensor<128xf32, #blocked1> + %136 = triton_gpu.convert_layout %135 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %137 = tt.expand_dims %136 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %138 = triton_gpu.convert_layout %137 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %139 = tt.broadcast %138 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %140 = arith.mulf %arg23, %139 : tensor<128x64xf32, #blocked2> + %141 = arith.truncf %133 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %144 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2> + %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> + %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.addf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %149 = triton_gpu.convert_layout %148 : (tensor<128xf32, #triton_gpu.slice<{dim = 
1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> + %150 = arith.addf %147, %149 : tensor<128xf32, #blocked1> + %151 = arith.addi %arg26, %c64_i64 : i64 + %152 = arith.addi %arg27, %c64_i64 : i64 + scf.yield %146, %150, %125, %151, %152 : tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64 + } + %43 = triton_gpu.convert_layout %42#1 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %45 = triton_gpu.convert_layout %44 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %46 = tt.broadcast %45 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %47 = arith.divf %42#0, %46 : tensor<128x64xf32, #blocked2> + %48 = arith.muli %1, %arg20 : i32 + %49 = tt.addptr %arg4, %48 : !tt.ptr, i32 + %50 = tt.splat %49 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked1> + %51 = tt.addptr %50, %15 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> + %52 = tt.extern_elementwise %42#1 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_log2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %53 = arith.addf %42#2, %52 : tensor<128xf32, #blocked1> + tt.store %51, %53 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked1> + %54 = tt.addptr %arg5, %2 : !tt.ptr, i32 + %55 = arith.extsi %arg17 : i32 to i64 + %56 = arith.extsi %5 : i32 to i64 + %57 = arith.truncf %47 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %58 = triton_gpu.convert_layout %57 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #blocked3> + %59 = tt.splat %54 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked3> + %60 = tt.splat %56 : (i64) -> tensor<128xi64, #blocked3> + %61 = 
tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> + %62 = arith.extsi %61 : tensor<128xi32, #blocked3> to tensor<128xi64, #blocked3> + %63 = arith.addi %60, %62 : tensor<128xi64, #blocked3> + %64 = triton_gpu.convert_layout %63 : (tensor<128xi64, #blocked3>) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<128x1xi64, #blocked4> + %66 = tt.splat %55 : (i64) -> tensor<128x1xi64, #blocked4> + %67 = arith.muli %65, %66 : tensor<128x1xi64, #blocked4> + %68 = tt.broadcast %67 : (tensor<128x1xi64, #blocked4>) -> tensor<128x64xi64, #blocked4> + %69 = triton_gpu.convert_layout %68 : (tensor<128x64xi64, #blocked4>) -> tensor<128x64xi64, #blocked3> + %70 = tt.addptr %59, %69 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %71 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %72 = arith.extsi %71 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %73 = triton_gpu.convert_layout %72 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %74 = tt.expand_dims %73 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %75 = tt.broadcast %74 : (tensor<1x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked5> + %76 = triton_gpu.convert_layout %75 : (tensor<128x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked3> + %77 = tt.addptr %70, %76 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + tt.store %77, %58 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked3> + tt.return + } +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 688e8ebd1810..1fbfaa9d4d07 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s 
-split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s #Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #Av2k1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}>