From c7ece064ae9261320e1fe05d69ac2f1ca8f8cf21 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Fri, 6 Dec 2024 11:11:07 -0800 Subject: [PATCH 1/8] fix an issue with two consumers when output of ForOp is used (#5) Summary: For an example with IfOp o = forOp mulf that uses o Prior to the change: IfOp o1 = forOp mulf that uses o1 IfOp o2 = forOp mulf that uses o1 After: IfOp o1 = forOp use o1 IfOp o2 = forOp use o2 We should not replace use of o with output of the specialized forOp while handling the first taskId, that will make handling of the next taskId incorrect. Instead of we update the mapping that is private to each taskId. --- lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index c3b6c4d86a..70fb206031 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -224,9 +224,7 @@ Operation *SpecializeForOp(scf::ForOp forOp, IRMapping &mapping, for (unsigned i = 0; i < usedArgs.size(); ++i) { auto oldResult = forOp.getResult(usedArgs[i]); auto newResult = newForOp.getResult(i); - oldResult.replaceUsesWithIf(newResult, [&](OpOperand &operand) -> bool { - return hasAsyncTaskId(operand.getOwner(), asyncTaskId); - }); + mapping.map(oldResult, newResult); } return newForOp; From 547b1e2499928a16301cf8fe1ba43f23a7db9045 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 27 Nov 2024 16:22:37 -0800 Subject: [PATCH 2/8] Support arbitrary data channel --- .../TritonGPU/Transforms/WSCodePartition.cpp | 427 +++++++++++++----- 1 file changed, 318 insertions(+), 109 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index 70fb206031..4daae21fce 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -383,29 +383,127 @@ struct Channel { public: using Relation = std::pair>; - Channel(int producer, SmallVector &consumers, Operation *src, - Operation *dst, Value srcOperand, unsigned numBuffers) - : relation(producer, consumers), srcOp(src), dstOp(dst), - srcOperand(srcOperand), numBuffers(numBuffers) {} + Channel(int producer, SmallVector &consumers, Operation *op, + unsigned operandIdx, unsigned numBuffers) + : relation(producer, consumers), op(op), operandIdx(operandIdx), + numBuffers(numBuffers) {} bool operator==(const Channel &c) { - return relation == c.relation && srcOp == c.srcOp && dstOp == c.dstOp; + return relation == c.relation && operandIdx == c.operandIdx && op == c.op; } + Operation *getDstOp() { return op; } + unsigned getDstOperandIdx() { return operandIdx; } + Value getSrcOperand() { return op->getOperand(operandIdx); } + Operation *getSrcOp() { return getSrcOperand().getDefiningOp(); } + Relation relation; // producer task Id, a list of consumer task Ids - Operation *srcOp; - Operation *dstOp; - Value srcOperand; + Operation *op; + unsigned operandIdx; unsigned numBuffers; }; +// Find transitive users of the root op. Ignore control flow ops such as yield +// in between. +void getTransitiveUsers(Value root, + SetVector> &users) { + for (Operation *userOp : root.getUsers()) { + if (auto yieldOp = dyn_cast(userOp)) { + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == root) { + auto result = + yieldOp->getParentOp()->getResult(operand.getOperandNumber()); + getTransitiveUsers(result, users); + } + } + } else { + // find operand index of root + unsigned operandIndex = 0; + for (OpOperand &operand : userOp->getOpOperands()) { + if (operand.get() == root) { + break; + } + operandIndex++; + } + assert(operandIndex < userOp->getNumOperands() && + "root is not an operand of userOp"); + users.insert({userOp, operandIndex}); + } + } +} + +// Value findTransitiveProducer(Value root) { +// SetVector users; +// getTransitiveUsers(root, users); +// assert(!users.empty() && "producer op has no consumers"); +// if (users.size() == 1) { +// return root; +// } + +// // Reconsile the user list. If there are multiple users, we need to +// // find a global producer in the same scope with the outtermost-scoped +// user. Region *currentScope = root.getDefiningOp()->getParentRegion(); +// Region *userScope = nullptr; + +// // check all users have same scope +// for (Operation *user : users) { +// auto currentScope = user->getParentRegion(); +// if (userScope) + +// auto parentRegions = getAllParentRegions(user); +// if (parentRegions.count(currentScope) == 0) { +// llvm::errs() << "user is not in the same scope as the producer\n"; +// llvm_unreachable("user is not in the same scope as the producer"); +// } + +// if (userScope->getParentOfType() == nullptr) { +// newScope = userScope; +// } +// } + +// auto getAllParentRegions = [](Operation *op) { +// SetVector parentRegions; +// Region *region = op->getParentRegion(); +// while (region) { +// parentRegions.insert(region); +// mlir::Operation *parentOp = region->getParentOp(); +// if (!parentOp) +// break; +// region = parentOp->getParentRegion(); +// } + +// return parentRegions; +// }; + +// // Find the outtermost-scoped user. +// for (Operation *user : users) { +// auto parentRegions = getAllParentRegions(user); +// if (parentRegions.count(currentScope) == 0) { +// llvm::errs() << "user is not in the same scope as the producer\n"; +// llvm_unreachable("user is not in the same scope as the producer"); +// } + +// if (userScope->getParentOfType() == nullptr) { +// newScope = userScope; +// } +// } + +// // Find the global producer in the same scope with the outtermost-scoped +// user. + +// // Promote producer to the same level as the consumer. + +// return root; +// } + // Loads will be in producer warp groups. For now, we only allow a single // warp group/task for a producer. For each LoadOp, create a channel from it // to any direct user which belongs to a different taskId. void collectAsyncChannels(SmallVector> &channels, triton::FuncOp &funcOp, unsigned numBuffers) { funcOp.walk([&](Operation *op) { - if (isa(op)) { + if (isa(op) || + op->hasTrait()) { auto producerTaskIds = getAsyncTaskIds(op); if (producerTaskIds.empty() || producerTaskIds.size() > 1) { LLVM_DEBUG({ @@ -425,7 +523,11 @@ void collectAsyncChannels(SmallVector> &channels, if (result.use_empty()) { continue; } - for (Operation *userOp : result.getUsers()) { + + SetVector> users; + getTransitiveUsers(result, users); + for (auto pc : users) { + auto userOp = pc.first; auto consumerTaskIds = getAsyncTaskIds(userOp); if (consumerTaskIds.empty()) continue; @@ -435,9 +537,9 @@ void collectAsyncChannels(SmallVector> &channels, consumerTaskIds.erase(iter, consumerTaskIds.end()); // Add a channel from the single producer task to consumerTaskIds. if (consumerTaskIds.size() > 0) { - channels.push_back( - std::make_unique(producerTaskId, consumerTaskIds, op, - userOp, result, producerNumBuffers)); + channels.push_back(std::make_unique( + producerTaskId, consumerTaskIds, userOp, pc.second, + producerNumBuffers)); } } } @@ -448,43 +550,40 @@ void collectAsyncChannels(SmallVector> &channels, LDBG("Async channels:"); for (auto &channel : channels) { LDBG("producer op: " << channel->relation.first); - channel->srcOp->dump(); + channel->getSrcOp()->dump(); for (auto &asyncTaskId : channel->relation.second) LDBG("consumer: " << asyncTaskId); - channel->dstOp->dump(); + channel->getDstOp()->dump(); LDBG("numBuffers: " << channel->numBuffers); } }); } -// Update map, which will be keyed by dstOp of the channel. Use mapKeyVec to -// enforce deterministic order for map. -void groupChannels(SmallVector &channels, - DenseMap> &map, - SmallVector &mapKeyVec) { +// Update map, which will be keyed by getDstOp() of the channel. Use mapKeyVec +// to enforce deterministic order for map. +void groupChannels( + SmallVector &channels, + DenseMap> &groupedChannels, + SmallVector &mapKeyVec) { // Two channels can be combined if // src1 and src2 are in the same block and // (dst1 == dst2 or // (dst1 and dst2 are in the same block, both have a single user, and // dst1User == dst2User and dst1User is in the same block as dst1)) auto channelCanBeMerged = [](Channel *c1, Channel *c2) -> bool { - if (c1->srcOp->getBlock() != c2->srcOp->getBlock()) + if (c1->getSrcOp()->getBlock() != c2->getSrcOp()->getBlock()) return false; - Operation *dst1 = c1->dstOp, *dst2 = c2->dstOp; + Operation *dst1 = c1->getDstOp(), *dst2 = c2->getDstOp(); if (dst1 == dst2) return true; - if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || - !dst2->hasOneUse()) - return false; - Operation *dst1User = *(dst1->getUsers().begin()); - Operation *dst2User = *(dst2->getUsers().begin()); - return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); + return dst1->getBlock() == dst2->getBlock(); }; assert(channels.size() > 0 && "channel size is zero"); - // Compare with existing channels in the map to see if it can be combined. + // Compare with existing channels in the groupedChannels to see if it can be + // combined. for (auto *c0 : channels) { bool merged = false; - for (auto &kv : map) { + for (auto &kv : groupedChannels) { if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { kv.second.push_back(c0); merged = true; @@ -492,35 +591,51 @@ void groupChannels(SmallVector &channels, } } if (!merged) { // Create a new entry. - auto *keyOp = c0->dstOp; - if (!map.count(keyOp)) + auto *keyOp = c0->getDstOp(); + if (!groupedChannels.count(keyOp)) mapKeyVec.push_back(keyOp); - map[keyOp].push_back(c0); + groupedChannels[keyOp].push_back(c0); } } // Reorder channels associated with one entry based on program order of the // producers. - for (auto &kv : map) { + for (auto &kv : groupedChannels) { if (kv.second.size() > 1) { - auto &allOps = kv.second.front()->srcOp->getBlock()->getOperations(); + auto &allOps = kv.second.front()->getSrcOp()->getBlock()->getOperations(); std::sort( kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { auto itrA = std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { Operation *opPointer = &op; - return opPointer == a->srcOp; + return opPointer == a->getSrcOp(); }); auto itrB = std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { Operation *opPointer = &op; - return opPointer == b->srcOp; + return opPointer == b->getSrcOp(); }); assert(itrA != allOps.end() && itrB != allOps.end()); return std::distance(itrA, itrB) < 0; }); } } + + LLVM_DEBUG({ + LDBG("\n\nGrouped channels:"); + for (auto &kv : groupedChannels) { + DBGS() << "consumer op: "; + kv.getFirst()->dump(); + for (auto &channel : kv.second) { + DBGS() << "producer: "; + channel->getSrcOp()->dump(); + for (auto &asyncTaskId : channel->relation.second) + DBGS() << asyncTaskId << ", "; + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + } + } + }); } // Reorder producer ops to unblock consumers interleavingly. @@ -529,9 +644,9 @@ void reorderProducerOps(SmallVector &channels) { return; // Bail out if channels are not in the same block - auto block = channels.front()->srcOp->getBlock(); + auto block = channels.front()->getSrcOp()->getBlock(); for (auto &channel : channels) { - if (channel->srcOp->getBlock() != block) { + if (channel->getSrcOp()->getBlock() != block) { return; } } @@ -560,11 +675,11 @@ void reorderProducerOps(SmallVector &channels) { // Start from the first producer in channels. Iterate through the groups // which are ordered by the first consumer taskId. Within each group, channels // are ordered by number of consumers. - Operation *currOp = channels.front()->srcOp; + Operation *currOp = channels.front()->getSrcOp(); for (auto &group : groupedProducerOps) { for (auto &channel : group.second) { - channel->srcOp->moveAfter(currOp); - currOp = channel->srcOp; + channel->getSrcOp()->moveAfter(currOp); + currOp = channel->getSrcOp(); } } @@ -577,10 +692,10 @@ void reorderProducerOps(SmallVector &channels) { BackwardSliceOptions opt; opt.omitBlockArguments = true; SetVector backwardSlice; - getBackwardSlice(channel->srcOp, &backwardSlice, opt); + getBackwardSlice(channel->getSrcOp(), &backwardSlice, opt); for (auto &op : backwardSlice) { if (op->getBlock() == block) - op->moveBefore(channel->srcOp); + op->moveBefore(channel->getSrcOp()); } } } @@ -742,16 +857,16 @@ scf::ForOp createNewLoop(scf::ForOp forOp, int numBuffers, return newForOp; } -// Find top-level ops which contain at least one channel. If a channel's srcOp -// and dstOp belong to the inner loop, the outer loop will be part of -// asyncTaskOps. +// Find top-level ops which contain at least one channel. If a channel's +// getSrcOp() and getDstOp() belong to the inner loop, the outer loop will be +// part of asyncTaskOps. SmallVector getTaskTopRegion(triton::FuncOp funcOp, const SmallVector &channels) { SmallVector asyncTaskOps; auto isAsyncTaskTopOp = [&](Operation *taskTopOp) -> bool { for (auto c : channels) { - Operation *producer = c->srcOp, *consumer = c->dstOp; + Operation *producer = c->getSrcOp(), *consumer = c->getDstOp(); while (producer && !isa(producer->getParentOp())) { producer = producer->getParentOp(); } @@ -815,6 +930,11 @@ void appendBufferIdxArgs(SmallVector &taskTopOps, *asyncTaskLoopForItr = newForOp.getOperation(); } } + + LLVM_DEBUG({ + LDBG("\n\nafter appendBufferIdxArgs"); + taskTopOps[0]->getParentOfType().dump(); + }); } // Create an allocation to hold the mbarriers. @@ -847,25 +967,25 @@ static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { return barrierAlloc; } -// map: channels are grouped together. +// groupedChannels: channels are grouped together. // Go through each group, check the first channel in the group, create a token // for each consumer taskId. Return a map that maps each channel + consumer // taskId to a token. Also update barrierAllocMap that maps each channel + // consumer taskId to a BarrierAlloc. -DenseMap> -createToken(const DenseMap> &map, - const SmallVector &mapKeyVec, triton::FuncOp funcOp, - int numConsumerGroups, - DenseMap> &barrierAllocMap) { +DenseMap> createToken( + const DenseMap> &groupedChannels, + const SmallVector &mapKeyVec, triton::FuncOp funcOp, + int numConsumerGroups, + DenseMap> &barrierAllocMap) { DenseMap> ret; OpBuilder builder(funcOp); builder.setInsertionPointToStart(&(funcOp.getBody().front())); for (auto *key : mapKeyVec) { - auto it = map.find(key); + auto it = groupedChannels.find(key); Channel *channel = it->second.front(); for (auto consumerAsyncTaskId : channel->relation.second) { Value v; - if (it->second.front()->srcOp->getParentOfType()) { + if (it->second.front()->getSrcOp()->getParentOfType()) { v = builder.create(funcOp.getLoc(), channel->numBuffers); } else { @@ -875,7 +995,7 @@ createToken(const DenseMap> &map, for (auto &c : it->second) ret[c][consumerAsyncTaskId] = v; - auto producerOp = it->second.front()->srcOp; + auto producerOp = it->second.front()->getSrcOp(); if (isa(producerOp)) { Value bAlloc = createBarrierAlloc(funcOp, channel->numBuffers); // Channels in the group share the same set of tokens. @@ -889,17 +1009,39 @@ createToken(const DenseMap> &map, return ret; } -// Create a buffer array for each channel, if the producer is in a ForOp, +// Create a buffer array for each producer op, if the producer is in a ForOp, // the buffer array will contain numBuffers. DenseMap createBuffer(const SmallVector &channels, triton::FuncOp funcOp, int numConsumerGroups) { + + // Group channels by producer op. + DenseMap> groupedChannels; + for (auto channel : channels) { + groupedChannels[channel->getSrcOperand()].push_back(channel); + } + +#ifndef NDEBUG + // Some sanity checks. + for (auto &item : groupedChannels) { + auto &channels = item.second; + unsigned numBuffers = channels.front()->numBuffers; + for (auto c : channels) { + assert(c->numBuffers == numBuffers && "Unmatched number of buffers"); + } + } +#endif + DenseMap bufferMap; MLIRContext *context = funcOp.getContext(); OpBuilder builder(funcOp); - builder.setInsertionPointToStart(&(funcOp.getBody().front())); - for (const auto &c : channels) { - if (auto tensorType = dyn_cast(c->srcOperand.getType())) { + for (auto &item : groupedChannels) { + auto &channels = item.second; + auto srcValue = item.first; + auto srcOp = srcValue.getDefiningOp(); + unsigned numBuffers = channels.front()->numBuffers; + + if (auto tensorType = dyn_cast(srcValue.getType())) { // Get basic information from tensorType auto order = ttg::getOrder(tensorType.getEncoding()); auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); @@ -914,8 +1056,8 @@ DenseMap createBuffer(const SmallVector &channels, // Get shape, layout and type of the complete buffer SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); - if (c->srcOp->getParentOfType()) - bufferShape.insert(bufferShape.begin(), c->numBuffers); + if (srcOp->getParentOfType()) + bufferShape.insert(bufferShape.begin(), numBuffers); else bufferShape.insert(bufferShape.begin(), 1); Attribute sharedMemorySpace = @@ -925,15 +1067,13 @@ DenseMap createBuffer(const SmallVector &channels, Type memdescType = tt::MemDescType::get(bufferShape, elemType, sharedLayout, sharedMemorySpace, /*mutableMemory*/ true); - Value buffer; - if (isa(c->srcOp)) { - buffer = - builder.create(funcOp.getLoc(), memdescType); - } else { - buffer = builder.create(funcOp.getLoc(), memdescType, - c->srcOperand); - } - bufferMap[c] = buffer; + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + Value buffer = + builder.create(funcOp.getLoc(), memdescType); + + // Channels in the group share the same buffer. + for (auto c : channels) + bufferMap[c] = buffer; } else { llvm_unreachable("Unexpected result type"); } @@ -991,7 +1131,7 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, // Extract part. builder.setAsyncTaskIdsFromValueUsers(loadResult); - builder.setInsertionPoint(c->dstOp); + builder.setInsertionPoint(c->getDstOp()); SmallVector loadOffsets(sliceType.getRank() + 1, zero); loadOffsets[0] = bufferIdxExtract; auto viewLoad = builder.createWithAsyncTaskIds( @@ -999,11 +1139,72 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, auto sharedLoad = builder.createWithAsyncTaskIds( loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); // Replace all uses of loadResult + // TODO: replace the real consumer loadResult.replaceAllUsesWith(sharedLoad.getResult()); loadOp.erase(); return copy; } +static Operation *createLocalCopy(const DenseMap &bufferMap, + Channel *channel, Value srcBufferIdx, + Value dstBufferIdx) { + Operation *srcOp = channel->getSrcOp(); + Operation *dstOp = channel->getDstOp(); + MLIRContext *context = srcOp->getContext(); + auto buffer = bufferMap.find(channel)->second; + + Value srcValue = channel->getSrcOperand(); + auto tensorType = dyn_cast(srcValue.getType()); + if (!tensorType) + return nullptr; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + + // Consumer part. + OpBuilderWithAsyncTaskIds dstBuilder(dstOp); + dstBuilder.setInsertionPoint(dstOp); + Value zero = dstBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 32); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = dstBufferIdx; + auto dstView = dstBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = dstBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), srcValue.getType(), dstView); + srcValue.replaceAllUsesWith(sharedLoad.getResult()); + + // Producer part. + OpBuilderWithAsyncTaskIds srcBuilder(srcOp); + srcBuilder.setInsertionPoint(srcOp->getParentOp()); + zero = srcBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 32); + SmallVector storeOffsets(sliceType.getRank() + 1, zero); + storeOffsets[0] = srcBufferIdx; + srcBuilder.setInsertionPointAfter(srcOp); + auto srcView = srcBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), subviewTy, buffer, storeOffsets); + // Create local_alloc + Operation *copy = srcBuilder.createWithAsyncTaskIds( + srcOp->getLoc(), srcValue, srcView); + + return copy; +} + static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { auto tensorTy = cast(tmaLoad->getResult(0).getType()); int loadSize = product(tensorTy.getShape()); @@ -1121,10 +1322,10 @@ optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, return copy; } -// Lower producers for channels. Here channels are grouped in "map". tokenMap -// tracks the set of tokens for each channel. +// Lower producers for channels. Here channels are grouped in "groupedChannels". +// tokenMap tracks the set of tokens for each channel. void buildAsyncComm( - const DenseMap> &map, + const DenseMap> &groupedChannels, const DenseMap> &tokenMap, const DenseMap> &barrierAllocMap, const DenseMap &bufferMap, int numConsumerGroups) { @@ -1170,11 +1371,11 @@ void buildAsyncComm( }; // Go through each channel group. - for (auto kv : map) { - auto headProducer = kv.second.front()->srcOp; - auto tailProducer = kv.second.back()->srcOp; - auto headConsumer = kv.second.front()->dstOp; - auto tailConsumer = kv.second.back()->dstOp; + for (auto kv : groupedChannels) { + auto headProducer = kv.second.front()->getSrcOp(); + auto tailProducer = kv.second.back()->getSrcOp(); + auto headConsumer = kv.second.front()->getDstOp(); + auto tailConsumer = kv.second.back()->getDstOp(); // We have one set of tokens for each channel group. auto tokens = tokenMap.find(kv.second.front())->second; @@ -1205,8 +1406,6 @@ void buildAsyncComm( headProducer->getLoc(), 0, 1); } - assert((isa(headProducer)) && - "producer must be a LoadOp or tma LoadOp"); builder.setAsynTaskIdsFromArray(asyncTaskP); for (auto token : tokens) { // Insert ProducerAcquireOp before the producer. @@ -1216,7 +1415,7 @@ void buildAsyncComm( // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering // will handle the ProducerCommit. - if (isa(headProducer)) { + if (!isa(headProducer)) { builder.setInsertionPointAfter(tailProducer); builder.createWithAsyncTaskIds( tailProducer->getLoc(), token.second, bufferIdx); @@ -1245,18 +1444,17 @@ void buildAsyncComm( SmallVector buffers; // Go through all channels in this channel group. for (auto &c : kv.second) { - assert( - (isa(c->srcOp)) && - "producer must be a LoadOp or tma LoadOp"); - bool insideLoop = c->srcOp->getParentOfType() != nullptr; - if (isa(c->srcOp)) { - // After createAsyncCopy, c->srcOp/headProducer are no longer valid. - createAsyncCopy(bufferMap, c, c->srcOp, asyncTasksPC, bufferIdx, + if (isa(c->getSrcOp())) { + // After createAsyncCopy, c->getSrcOp()/headProducer are no longer + // valid. + createAsyncCopy(bufferMap, c, c->getSrcOp(), asyncTasksPC, bufferIdx, bufferIdx); - } else if (auto tmaLoad = - dyn_cast(c->srcOp)) { + } else if (auto tmaLoad = dyn_cast( + c->getSrcOp())) { tmaLoads.push_back(tmaLoad); buffers.push_back(bufferMap.find(c)->second); + } else { + createLocalCopy(bufferMap, c, bufferIdx, bufferIdx); } } @@ -1272,6 +1470,25 @@ void buildAsyncComm( } } +void foldLocalLoads(triton::FuncOp funcOp) { + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + DenseMap opsToReplace; + funcOp.walk([&](ttg::LocalAllocOp localAlloc) { + if (auto src = localAlloc.getSrc()) { + if (auto localLoad = dyn_cast(src.getDefiningOp())) { + // Only fold within the same tasks + if (getAsyncTaskIds(localLoad) == getAsyncTaskIds(localAlloc)) { + opsToReplace[localAlloc] = localLoad.getSrc(); + } + } + } + }); + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + for (auto kv : opsToReplace) + replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); +} + class TritonGPUWSCodePartitionPass : public impl::TritonGPUWSCodePartitionBase { public: @@ -1294,10 +1511,11 @@ class TritonGPUWSCodePartitionPass return; } - // Step 2: group channels where each entry of the map is keyed by the dstOp. - DenseMap> map; + // Step 2: group channels where each entry of the groupedChannels is keyed + // by the dstOp. + DenseMap> groupedChannels; SmallVector mapKeyVec; - groupChannels(channels, map, mapKeyVec); + groupChannels(channels, groupedChannels, mapKeyVec); // Step 3: reorder producer ops and the backward slices of the producer ops. reorderProducerOps(channels); @@ -1316,8 +1534,8 @@ class TritonGPUWSCodePartitionPass // Step 5: Create tokens, and buffers. A set of tokens for each group of // channels and an array of buffers for each channel. DenseMap> barrierAllocMap; - DenseMap> tokenMap = - createToken(map, mapKeyVec, funcOp, numConsumerGroups, barrierAllocMap); + DenseMap> tokenMap = createToken( + groupedChannels, mapKeyVec, funcOp, numConsumerGroups, barrierAllocMap); DenseMap bufferMap = createBuffer(channels, funcOp, numConsumerGroups); LLVM_DEBUG({ @@ -1327,7 +1545,7 @@ class TritonGPUWSCodePartitionPass // Step 6: add async communication ops (ProducerAcquire etc). Also lower the // loads. - buildAsyncComm(map, tokenMap, barrierAllocMap, bufferMap, + buildAsyncComm(groupedChannels, tokenMap, barrierAllocMap, bufferMap, numConsumerGroups); LLVM_DEBUG({ LDBG("\n\nwith SyncOps"); @@ -1337,16 +1555,7 @@ class TritonGPUWSCodePartitionPass // If loadResult has a single use which is LocalAlloc, we can get rid of // sharedLoad and replace all uses of LocalAlloc with viewLoad. DenseMap opsToReplace; - funcOp.walk([&](ttg::LocalAllocOp localAlloc) { - if (auto src = localAlloc.getSrc()) { - if (auto localLoad = dyn_cast(src.getDefiningOp())) { - opsToReplace[localAlloc] = localLoad.getSrc(); - } - } - }); - OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); - for (auto kv : opsToReplace) - replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); + foldLocalLoads(funcOp); LLVM_DEBUG({ LDBG("\n\nsimplify localLoad + localAlloc"); funcOp.dump(); From 9a5aa57b79d581576e07286e93a82045fe386585 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Thu, 5 Dec 2024 23:18:18 -0800 Subject: [PATCH 3/8] createLocalCopy --- .../TritonGPU/Transforms/WSCodePartition.cpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index 4daae21fce..d85d668ff3 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -1145,18 +1145,22 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, return copy; } -static Operation *createLocalCopy(const DenseMap &bufferMap, - Channel *channel, Value srcBufferIdx, - Value dstBufferIdx) { +static void createLocalCopy(const DenseMap &bufferMap, + DenseMap &producerCopyMap, + Channel *channel, Value srcBufferIdx, + Value dstBufferIdx) { Operation *srcOp = channel->getSrcOp(); Operation *dstOp = channel->getDstOp(); + if (producerCopyMap.contains(srcOp)) + return; + MLIRContext *context = srcOp->getContext(); auto buffer = bufferMap.find(channel)->second; Value srcValue = channel->getSrcOperand(); auto tensorType = dyn_cast(srcValue.getType()); if (!tensorType) - return nullptr; + return; // Get basic information from tensorType auto order = ttg::getOrder(tensorType.getEncoding()); auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); @@ -1177,7 +1181,7 @@ static Operation *createLocalCopy(const DenseMap &bufferMap, // Consumer part. OpBuilderWithAsyncTaskIds dstBuilder(dstOp); - dstBuilder.setInsertionPoint(dstOp); + dstBuilder.setInsertionPointAfter(srcOp); Value zero = dstBuilder.createWithAsyncTaskIds( srcOp->getLoc(), 0, 32); SmallVector loadOffsets(sliceType.getRank() + 1, zero); @@ -1188,7 +1192,7 @@ static Operation *createLocalCopy(const DenseMap &bufferMap, srcOp->getLoc(), srcValue.getType(), dstView); srcValue.replaceAllUsesWith(sharedLoad.getResult()); - // Producer part. + // Producer part. Create local_store for new producers. OpBuilderWithAsyncTaskIds srcBuilder(srcOp); srcBuilder.setInsertionPoint(srcOp->getParentOp()); zero = srcBuilder.createWithAsyncTaskIds( @@ -1201,8 +1205,7 @@ static Operation *createLocalCopy(const DenseMap &bufferMap, // Create local_alloc Operation *copy = srcBuilder.createWithAsyncTaskIds( srcOp->getLoc(), srcValue, srcView); - - return copy; + producerCopyMap[srcOp] = copy; } static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { @@ -1442,6 +1445,7 @@ void buildAsyncComm( SmallVector tmaLoads; SmallVector buffers; + DenseMap producerCopyMap; // Go through all channels in this channel group. for (auto &c : kv.second) { if (isa(c->getSrcOp())) { @@ -1454,7 +1458,7 @@ void buildAsyncComm( tmaLoads.push_back(tmaLoad); buffers.push_back(bufferMap.find(c)->second); } else { - createLocalCopy(bufferMap, c, bufferIdx, bufferIdx); + createLocalCopy(bufferMap, producerCopyMap, c, bufferIdx, bufferIdx); } } From dfd47a2462411ca6e2502cf2e79070fb1a7703ac Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 6 Dec 2024 15:29:50 -0800 Subject: [PATCH 4/8] group channels --- .../TritonGPU/Transforms/WSCodePartition.cpp | 334 ++++++++++-------- 1 file changed, 190 insertions(+), 144 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index d85d668ff3..6ea0596ca2 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -432,70 +432,6 @@ void getTransitiveUsers(Value root, } } -// Value findTransitiveProducer(Value root) { -// SetVector users; -// getTransitiveUsers(root, users); -// assert(!users.empty() && "producer op has no consumers"); -// if (users.size() == 1) { -// return root; -// } - -// // Reconsile the user list. If there are multiple users, we need to -// // find a global producer in the same scope with the outtermost-scoped -// user. Region *currentScope = root.getDefiningOp()->getParentRegion(); -// Region *userScope = nullptr; - -// // check all users have same scope -// for (Operation *user : users) { -// auto currentScope = user->getParentRegion(); -// if (userScope) - -// auto parentRegions = getAllParentRegions(user); -// if (parentRegions.count(currentScope) == 0) { -// llvm::errs() << "user is not in the same scope as the producer\n"; -// llvm_unreachable("user is not in the same scope as the producer"); -// } - -// if (userScope->getParentOfType() == nullptr) { -// newScope = userScope; -// } -// } - -// auto getAllParentRegions = [](Operation *op) { -// SetVector parentRegions; -// Region *region = op->getParentRegion(); -// while (region) { -// parentRegions.insert(region); -// mlir::Operation *parentOp = region->getParentOp(); -// if (!parentOp) -// break; -// region = parentOp->getParentRegion(); -// } - -// return parentRegions; -// }; - -// // Find the outtermost-scoped user. -// for (Operation *user : users) { -// auto parentRegions = getAllParentRegions(user); -// if (parentRegions.count(currentScope) == 0) { -// llvm::errs() << "user is not in the same scope as the producer\n"; -// llvm_unreachable("user is not in the same scope as the producer"); -// } - -// if (userScope->getParentOfType() == nullptr) { -// newScope = userScope; -// } -// } - -// // Find the global producer in the same scope with the outtermost-scoped -// user. - -// // Promote producer to the same level as the consumer. - -// return root; -// } - // Loads will be in producer warp groups. For now, we only allow a single // warp group/task for a producer. For each LoadOp, create a channel from it // to any direct user which belongs to a different taskId. @@ -559,12 +495,34 @@ void collectAsyncChannels(SmallVector> &channels, }); } -// Update map, which will be keyed by getDstOp() of the channel. Use mapKeyVec -// to enforce deterministic order for map. +// Update map, which will be keyed by getDstOp() of the channel. Use +// orderedChannels to enforce deterministic order for map. void groupChannels( SmallVector &channels, - DenseMap> &groupedChannels, - SmallVector &mapKeyVec) { + DenseMap> &channelsGroupedByProducers, + DenseMap> &channelsGroupedByConsumers, + SmallVector &orderedChannels) { + + // Group channels by producer op. + DenseMap> producerChannels; + for (auto channel : channels) { + producerChannels[channel->getSrcOp()].push_back(channel); + } + +#ifndef NDEBUG + // Some sanity checks. + for (auto &item : producerChannels) { + auto &channels = item.second; + unsigned numBuffers = channels.front()->numBuffers; + for (auto c : channels) { + assert(c->numBuffers == numBuffers && "Unmatched number of buffers"); + } + } +#endif + + // Group channels by consumer op. + DenseMap> consumerChannels; + // Two channels can be combined if // src1 and src2 are in the same block and // (dst1 == dst2 or @@ -579,11 +537,11 @@ void groupChannels( return dst1->getBlock() == dst2->getBlock(); }; assert(channels.size() > 0 && "channel size is zero"); - // Compare with existing channels in the groupedChannels to see if it can be - // combined. + // Compare with existing channels in the consumerChannels to see if + // it can be combined. for (auto *c0 : channels) { bool merged = false; - for (auto &kv : groupedChannels) { + for (auto &kv : consumerChannels) { if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { kv.second.push_back(c0); merged = true; @@ -592,15 +550,15 @@ void groupChannels( } if (!merged) { // Create a new entry. auto *keyOp = c0->getDstOp(); - if (!groupedChannels.count(keyOp)) - mapKeyVec.push_back(keyOp); - groupedChannels[keyOp].push_back(c0); + if (!consumerChannels.count(keyOp)) + orderedChannels.push_back(c0); + consumerChannels[keyOp].push_back(c0); } } // Reorder channels associated with one entry based on program order of the // producers. - for (auto &kv : groupedChannels) { + for (auto &kv : consumerChannels) { if (kv.second.size() > 1) { auto &allOps = kv.second.front()->getSrcOp()->getBlock()->getOperations(); std::sort( @@ -621,11 +579,38 @@ void groupChannels( } } + // Switch to using channel as the key instead of ops as ops can be volatile. + for (auto &kv : producerChannels) { + channelsGroupedByProducers[kv.second.front()] = kv.second; + } + for (auto &kv : consumerChannels) { + channelsGroupedByConsumers[kv.second.front()] = kv.second; + } + LLVM_DEBUG({ - LDBG("\n\nGrouped channels:"); - for (auto &kv : groupedChannels) { - DBGS() << "consumer op: "; - kv.getFirst()->dump(); + DBGS() << "\n\n"; + LDBG("Grouped channels by producer:"); + unsigned i = 0; + for (auto &kv : channelsGroupedByProducers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "producer: "; + kv.getFirst()->getSrcOp()->dump(); + for (auto &channel : kv.second) { + DBGS() << "consumer: "; + channel->getDstOp()->dump(); + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; + } + } + + DBGS() << "\n\n"; + LDBG("Grouped channels by consumer:"); + i = 0; + for (auto &kv : channelsGroupedByConsumers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "consumer: "; + kv.getFirst()->getDstOp()->dump(); for (auto &channel : kv.second) { DBGS() << "producer: "; channel->getSrcOp()->dump(); @@ -633,7 +618,9 @@ void groupChannels( DBGS() << asyncTaskId << ", "; DBGS() << "] "; LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; } + DBGS() << "\n"; } }); } @@ -930,11 +917,6 @@ void appendBufferIdxArgs(SmallVector &taskTopOps, *asyncTaskLoopForItr = newForOp.getOperation(); } } - - LLVM_DEBUG({ - LDBG("\n\nafter appendBufferIdxArgs"); - taskTopOps[0]->getParentOfType().dump(); - }); } // Create an allocation to hold the mbarriers. @@ -967,21 +949,22 @@ static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { return barrierAlloc; } -// groupedChannels: channels are grouped together. +// channelsGroupedByConsumers: channels are grouped together. // Go through each group, check the first channel in the group, create a token // for each consumer taskId. Return a map that maps each channel + consumer // taskId to a token. Also update barrierAllocMap that maps each channel + // consumer taskId to a BarrierAlloc. -DenseMap> createToken( - const DenseMap> &groupedChannels, - const SmallVector &mapKeyVec, triton::FuncOp funcOp, - int numConsumerGroups, - DenseMap> &barrierAllocMap) { +DenseMap> +createToken(const DenseMap> + &channelsGroupedByConsumers, + const SmallVector &orderedChannels, + triton::FuncOp funcOp, int numConsumerGroups, + DenseMap> &barrierAllocMap) { DenseMap> ret; OpBuilder builder(funcOp); builder.setInsertionPointToStart(&(funcOp.getBody().front())); - for (auto *key : mapKeyVec) { - auto it = groupedChannels.find(key); + for (auto *key : orderedChannels) { + auto it = channelsGroupedByConsumers.find(key); Channel *channel = it->second.front(); for (auto consumerAsyncTaskId : channel->relation.second) { Value v; @@ -1011,34 +994,17 @@ DenseMap> createToken( // Create a buffer array for each producer op, if the producer is in a ForOp, // the buffer array will contain numBuffers. -DenseMap createBuffer(const SmallVector &channels, - triton::FuncOp funcOp, - int numConsumerGroups) { - - // Group channels by producer op. - DenseMap> groupedChannels; - for (auto channel : channels) { - groupedChannels[channel->getSrcOperand()].push_back(channel); - } - -#ifndef NDEBUG - // Some sanity checks. - for (auto &item : groupedChannels) { - auto &channels = item.second; - unsigned numBuffers = channels.front()->numBuffers; - for (auto c : channels) { - assert(c->numBuffers == numBuffers && "Unmatched number of buffers"); - } - } -#endif +DenseMap createBuffer( + DenseMap> &channelsGroupedByProducers, + triton::FuncOp funcOp, int numConsumerGroups) { DenseMap bufferMap; MLIRContext *context = funcOp.getContext(); OpBuilder builder(funcOp); - for (auto &item : groupedChannels) { + for (auto &item : channelsGroupedByProducers) { auto &channels = item.second; - auto srcValue = item.first; - auto srcOp = srcValue.getDefiningOp(); + auto srcValue = item.first->getSrcOperand(); + auto srcOp = item.first->getSrcOp(); unsigned numBuffers = channels.front()->numBuffers; if (auto tensorType = dyn_cast(srcValue.getType())) { @@ -1146,14 +1112,10 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, } static void createLocalCopy(const DenseMap &bufferMap, - DenseMap &producerCopyMap, Channel *channel, Value srcBufferIdx, Value dstBufferIdx) { Operation *srcOp = channel->getSrcOp(); Operation *dstOp = channel->getDstOp(); - if (producerCopyMap.contains(srcOp)) - return; - MLIRContext *context = srcOp->getContext(); auto buffer = bufferMap.find(channel)->second; @@ -1205,7 +1167,6 @@ static void createLocalCopy(const DenseMap &bufferMap, // Create local_alloc Operation *copy = srcBuilder.createWithAsyncTaskIds( srcOp->getLoc(), srcValue, srcView); - producerCopyMap[srcOp] = copy; } static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { @@ -1325,10 +1286,12 @@ optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, return copy; } -// Lower producers for channels. Here channels are grouped in "groupedChannels". -// tokenMap tracks the set of tokens for each channel. -void buildAsyncComm( - const DenseMap> &groupedChannels, +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByConsumers". tokenMap tracks the set of tokens for each +// channel. +void insertAsyncComm( + const DenseMap> + &channelsGroupedByConsumers, const DenseMap> &tokenMap, const DenseMap> &barrierAllocMap, const DenseMap &bufferMap, int numConsumerGroups) { @@ -1374,7 +1337,7 @@ void buildAsyncComm( }; // Go through each channel group. - for (auto kv : groupedChannels) { + for (auto kv : channelsGroupedByConsumers) { auto headProducer = kv.second.front()->getSrcOp(); auto tailProducer = kv.second.back()->getSrcOp(); auto headConsumer = kv.second.front()->getDstOp(); @@ -1448,17 +1411,10 @@ void buildAsyncComm( DenseMap producerCopyMap; // Go through all channels in this channel group. for (auto &c : kv.second) { - if (isa(c->getSrcOp())) { - // After createAsyncCopy, c->getSrcOp()/headProducer are no longer - // valid. - createAsyncCopy(bufferMap, c, c->getSrcOp(), asyncTasksPC, bufferIdx, - bufferIdx); - } else if (auto tmaLoad = dyn_cast( - c->getSrcOp())) { + if (auto tmaLoad = + dyn_cast(c->getSrcOp())) { tmaLoads.push_back(tmaLoad); buffers.push_back(bufferMap.find(c)->second); - } else { - createLocalCopy(bufferMap, producerCopyMap, c, bufferIdx, bufferIdx); } } @@ -1474,6 +1430,85 @@ void buildAsyncComm( } } +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByProducers" +void insertAsyncCopy(triton::FuncOp funcOp, + const DenseMap> + &channelsGroupedByProducers, + const DenseMap &bufferMap) { + + auto getAsyncTasks = [&](Operation *p, Operation *c, + SmallVector &asyncTaskP, + SmallVector &asyncTaskC, + SmallVector &asyncTasksPC) -> void { + asyncTaskP = getNestedAsyncTaskIds(p); + asyncTaskC = getNestedAsyncTaskIds(c); + asyncTasksPC.reserve(asyncTaskP.size() + asyncTaskC.size()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskP.begin(), + asyncTaskP.end()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); + }; + + // For each producer op, create a async_copy or local_store from the producer + // to the buffer. Create a local_load from the buffer at the dominating + // consumer. + mlir::DominanceInfo dom(funcOp); + + for (auto kv : channelsGroupedByProducers) { + // Finding the dominating channel if possible. + std::unordered_set mutuallyNonDominatingChannels; + for (auto &c : kv.second) { + // check if c is dominating all other previous channels. + auto it = mutuallyNonDominatingChannels.begin(); + while (it != mutuallyNonDominatingChannels.end()) { + auto channel = *it; + if (dom.properlyDominates(c->getDstOp(), channel->getDstOp())) { + it = mutuallyNonDominatingChannels.erase(it); + } else if (dom.properlyDominates(channel->getDstOp(), c->getDstOp())) { + break; + } else { + ++it; + } + } + if (it == mutuallyNonDominatingChannels.end()) + mutuallyNonDominatingChannels.insert(c); + } + + auto srcOp = kv.getFirst()->getSrcOp(); + Value bufferIdx; + Value phase = Value(); + if (auto forOp = srcOp->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument + } else { + llvm_unreachable("Producer is not in a ForOp"); + } + + for (auto channel : mutuallyNonDominatingChannels) { + // No need to create async copy for TMA load which is handled in + // insertAsyncComm. + if (auto tmaLoad = dyn_cast(srcOp)) { + continue; + } + if (isa(srcOp)) { + SmallVector asyncTaskP, asyncTaskC, asyncTasksPC; + getAsyncTasks(srcOp, channel->getDstOp(), asyncTaskP, asyncTaskC, + asyncTasksPC); + // After createAsyncCopy, c->getSrcOp()/headProducer are no longer + // valid. + createAsyncCopy(bufferMap, channel, channel->getSrcOp(), asyncTasksPC, + bufferIdx, bufferIdx); + } else { + createLocalCopy(bufferMap, channel, bufferIdx, bufferIdx); + } + } + } +} + void foldLocalLoads(triton::FuncOp funcOp) { // If loadResult has a single use which is LocalAlloc, we can get rid of // sharedLoad and replace all uses of LocalAlloc with viewLoad. @@ -1515,11 +1550,15 @@ class TritonGPUWSCodePartitionPass return; } - // Step 2: group channels where each entry of the groupedChannels is keyed - // by the dstOp. - DenseMap> groupedChannels; - SmallVector mapKeyVec; - groupChannels(channels, groupedChannels, mapKeyVec); + // Step 2: group channels + // - each entry of the channelsGroupedByProducers is keyed by the srcOp. + // - each entry of the channelsGroupedByConsumers is keyed by the dstOp. + DenseMap> channelsGroupedByProducers; + DenseMap> channelsGroupedByConsumers; + + SmallVector orderedChannels; + groupChannels(channels, channelsGroupedByProducers, + channelsGroupedByConsumers, orderedChannels); // Step 3: reorder producer ops and the backward slices of the producer ops. reorderProducerOps(channels); @@ -1538,10 +1577,11 @@ class TritonGPUWSCodePartitionPass // Step 5: Create tokens, and buffers. A set of tokens for each group of // channels and an array of buffers for each channel. DenseMap> barrierAllocMap; - DenseMap> tokenMap = createToken( - groupedChannels, mapKeyVec, funcOp, numConsumerGroups, barrierAllocMap); + DenseMap> tokenMap = + createToken(channelsGroupedByConsumers, orderedChannels, funcOp, + numConsumerGroups, barrierAllocMap); DenseMap bufferMap = - createBuffer(channels, funcOp, numConsumerGroups); + createBuffer(channelsGroupedByProducers, funcOp, numConsumerGroups); LLVM_DEBUG({ LDBG("\n\nafter createBuffer"); funcOp.dump(); @@ -1549,13 +1589,19 @@ class TritonGPUWSCodePartitionPass // Step 6: add async communication ops (ProducerAcquire etc). Also lower the // loads. - buildAsyncComm(groupedChannels, tokenMap, barrierAllocMap, bufferMap, - numConsumerGroups); + insertAsyncComm(channelsGroupedByConsumers, tokenMap, barrierAllocMap, + bufferMap, numConsumerGroups); LLVM_DEBUG({ LDBG("\n\nwith SyncOps"); funcOp.dump(); }); + insertAsyncCopy(funcOp, channelsGroupedByProducers, bufferMap); + LLVM_DEBUG({ + LDBG("\n\nwith async copy"); + funcOp.dump(); + }); + // If loadResult has a single use which is LocalAlloc, we can get rid of // sharedLoad and replace all uses of LocalAlloc with viewLoad. DenseMap opsToReplace; From bc9eac3984a7574f989223e039614689c7615ab2 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 6 Dec 2024 17:17:15 -0800 Subject: [PATCH 5/8] fix head/tail consumer finding --- .../TritonGPU/Transforms/WSCodePartition.cpp | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index 6ea0596ca2..b83edeb73f 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -1338,10 +1338,49 @@ void insertAsyncComm( // Go through each channel group. for (auto kv : channelsGroupedByConsumers) { - auto headProducer = kv.second.front()->getSrcOp(); - auto tailProducer = kv.second.back()->getSrcOp(); - auto headConsumer = kv.second.front()->getDstOp(); - auto tailConsumer = kv.second.back()->getDstOp(); + // Find head and tail ops. + DenseSet producerOps; + DenseSet consumerOps; + for (auto &c : kv.second) { + producerOps.insert(c->getSrcOp()); + consumerOps.insert(c->getDstOp()); + } + + // Find head producer + auto producerBlock = kv.second.front()->getSrcOp()->getBlock(); + Operation *headProducer = nullptr; + for (auto &op : producerBlock->getOperations()) { + if (producerOps.count(&op)) { + headProducer = &op; + break; + } + } + // Find tail producer + Operation *tailProducer = nullptr; + for (auto &op : reverse(producerBlock->getOperations())) { + if (producerOps.count(&op)) { + tailProducer = &op; + break; + } + } + + // Find head consumer and tail consumer + auto consumerBlock = kv.second.front()->getDstOp()->getBlock(); + Operation *headConsumer = nullptr; + for (auto &op : consumerBlock->getOperations()) { + if (consumerOps.count(&op)) { + headConsumer = &op; + break; + } + } + Operation *tailConsumer = nullptr; + for (auto &op : reverse(consumerBlock->getOperations())) { + if (consumerOps.count(&op)) { + tailConsumer = &op; + break; + } + } + // We have one set of tokens for each channel group. auto tokens = tokenMap.find(kv.second.front())->second; From 5f19398ce4d1ad11779b62097f52658ac3aaf67a Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 6 Dec 2024 18:07:13 -0800 Subject: [PATCH 6/8] bug fixes --- .../TritonGPU/Transforms/WSCodePartition.cpp | 99 +++++++++---------- 1 file changed, 45 insertions(+), 54 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index b83edeb73f..c688f0986c 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -534,7 +534,12 @@ void groupChannels( Operation *dst1 = c1->getDstOp(), *dst2 = c2->getDstOp(); if (dst1 == dst2) return true; - return dst1->getBlock() == dst2->getBlock(); + if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || + !dst2->hasOneUse()) + return false; + Operation *dst1User = *(dst1->getUsers().begin()); + Operation *dst2User = *(dst2->getUsers().begin()); + return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); }; assert(channels.size() > 0 && "channel size is zero"); // Compare with existing channels in the consumerChannels to see if @@ -1105,7 +1110,6 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, auto sharedLoad = builder.createWithAsyncTaskIds( loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); // Replace all uses of loadResult - // TODO: replace the real consumer loadResult.replaceAllUsesWith(sharedLoad.getResult()); loadOp.erase(); return copy; @@ -1142,30 +1146,31 @@ static void createLocalCopy(const DenseMap &bufferMap, /*mutableMemory=*/true); // Consumer part. - OpBuilderWithAsyncTaskIds dstBuilder(dstOp); - dstBuilder.setInsertionPointAfter(srcOp); - Value zero = dstBuilder.createWithAsyncTaskIds( - srcOp->getLoc(), 0, 32); + OpBuilderWithAsyncTaskIds builder(dstOp); + builder.setAsyncTaskIdsFromOp(dstOp); + builder.setInsertionPoint(dstOp); + Value zero = builder.createWithAsyncTaskIds( + dstOp->getLoc(), 0, 32); SmallVector loadOffsets(sliceType.getRank() + 1, zero); loadOffsets[0] = dstBufferIdx; - auto dstView = dstBuilder.createWithAsyncTaskIds( - srcOp->getLoc(), subviewTy, buffer, loadOffsets); - auto sharedLoad = dstBuilder.createWithAsyncTaskIds( - srcOp->getLoc(), srcValue.getType(), dstView); + auto dstView = builder.createWithAsyncTaskIds( + dstOp->getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + dstOp->getLoc(), srcValue.getType(), dstView); srcValue.replaceAllUsesWith(sharedLoad.getResult()); // Producer part. Create local_store for new producers. - OpBuilderWithAsyncTaskIds srcBuilder(srcOp); - srcBuilder.setInsertionPoint(srcOp->getParentOp()); - zero = srcBuilder.createWithAsyncTaskIds( - srcOp->getLoc(), 0, 32); + builder.setAsynTaskIdsFromArray(channel->relation.first); + builder.setInsertionPoint(srcOp->getParentOp()); + zero = builder.createWithAsyncTaskIds(srcOp->getLoc(), + 0, 32); SmallVector storeOffsets(sliceType.getRank() + 1, zero); storeOffsets[0] = srcBufferIdx; - srcBuilder.setInsertionPointAfter(srcOp); - auto srcView = srcBuilder.createWithAsyncTaskIds( + builder.setInsertionPointAfter(srcOp); + auto srcView = builder.createWithAsyncTaskIds( srcOp->getLoc(), subviewTy, buffer, storeOffsets); // Create local_alloc - Operation *copy = srcBuilder.createWithAsyncTaskIds( + Operation *copy = builder.createWithAsyncTaskIds( srcOp->getLoc(), srcValue, srcView); } @@ -1323,19 +1328,6 @@ void insertAsyncComm( return nullptr; }; - auto getAsyncTasks = [&](Operation *p, Operation *c, - SmallVector &asyncTaskP, - SmallVector &asyncTaskC, - SmallVector &asyncTasksPC) -> void { - asyncTaskP = getNestedAsyncTaskIds(p); - asyncTaskC = getNestedAsyncTaskIds(c); - asyncTasksPC.reserve(asyncTaskP.size() + asyncTaskC.size()); - asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskP.begin(), - asyncTaskP.end()); - asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), - asyncTaskC.end()); - }; - // Go through each channel group. for (auto kv : channelsGroupedByConsumers) { // Find head and tail ops. @@ -1383,10 +1375,15 @@ void insertAsyncComm( // We have one set of tokens for each channel group. auto tokens = tokenMap.find(kv.second.front())->second; + auto masterChannel = kv.getFirst(); + + SmallVector asyncTaskP; + asyncTaskP.push_back(masterChannel->relation.first); + SmallVector &asyncTaskC = masterChannel->relation.second; + SmallVector asyncTasksPC = asyncTaskP; + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); - SmallVector asyncTaskP, asyncTaskC, asyncTasksPC; - getAsyncTasks(headProducer, headConsumer, asyncTaskP, asyncTaskC, - asyncTasksPC); OpBuilderWithAsyncTaskIds builder(headProducer->getContext()); if (auto funcOp = dyn_cast(headProducer->getParentOp())) { builder.setInsertionPointToStart(&(funcOp.getBody().front())); @@ -1411,7 +1408,7 @@ void insertAsyncComm( headProducer->getLoc(), 0, 1); } - builder.setAsynTaskIdsFromArray(asyncTaskP); + builder.setAsynTaskIdsFromArray(masterChannel->relation.first); for (auto token : tokens) { // Insert ProducerAcquireOp before the producer. builder.setInsertionPoint(headProducer); @@ -1475,20 +1472,6 @@ void insertAsyncCopy(triton::FuncOp funcOp, const DenseMap> &channelsGroupedByProducers, const DenseMap &bufferMap) { - - auto getAsyncTasks = [&](Operation *p, Operation *c, - SmallVector &asyncTaskP, - SmallVector &asyncTaskC, - SmallVector &asyncTasksPC) -> void { - asyncTaskP = getNestedAsyncTaskIds(p); - asyncTaskC = getNestedAsyncTaskIds(c); - asyncTasksPC.reserve(asyncTaskP.size() + asyncTaskC.size()); - asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskP.begin(), - asyncTaskP.end()); - asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), - asyncTaskC.end()); - }; - // For each producer op, create a async_copy or local_store from the producer // to the buffer. Create a local_load from the buffer at the dominating // consumer. @@ -1524,19 +1507,28 @@ void insertAsyncCopy(triton::FuncOp funcOp, bufferIdx = forOp.getBody()->getArguments().back(); phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument } else { - llvm_unreachable("Producer is not in a ForOp"); + // Producer is not in a ForOp, create phase and bufferIdx here which will + // be used by both producer and consumers. + OpBuilderWithAsyncTaskIds builder(srcOp); + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + for (auto channel : mutuallyNonDominatingChannels) + asyncTasksPC.append(getAsyncTaskIds(channel->getDstOp())); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + bufferIdx = builder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 32); + phase = builder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 1); } for (auto channel : mutuallyNonDominatingChannels) { // No need to create async copy for TMA load which is handled in // insertAsyncComm. - if (auto tmaLoad = dyn_cast(srcOp)) { + if (isa(srcOp)) { continue; } if (isa(srcOp)) { - SmallVector asyncTaskP, asyncTaskC, asyncTasksPC; - getAsyncTasks(srcOp, channel->getDstOp(), asyncTaskP, asyncTaskC, - asyncTasksPC); + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + asyncTasksPC.append(getAsyncTaskIds(channel->getDstOp())); // After createAsyncCopy, c->getSrcOp()/headProducer are no longer // valid. createAsyncCopy(bufferMap, channel, channel->getSrcOp(), asyncTasksPC, @@ -1594,7 +1586,6 @@ class TritonGPUWSCodePartitionPass // - each entry of the channelsGroupedByConsumers is keyed by the dstOp. DenseMap> channelsGroupedByProducers; DenseMap> channelsGroupedByConsumers; - SmallVector orderedChannels; groupChannels(channels, channelsGroupedByProducers, channelsGroupedByConsumers, orderedChannels); From 735779cd266fd2052d63c6e8b5e9bbd473f741a6 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 9 Dec 2024 11:13:29 -0800 Subject: [PATCH 7/8] A lit test --- .../WarpSpecialization/ws_code_partition.mlir | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir index 0461ce39b6..252fbaf9f0 100644 --- a/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir @@ -304,3 +304,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + + +// ----- + +// CHECK-LABEL: @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog +// CHECK: %[[#TASKID:]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %c0_i32_0 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32_0 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_nvidia_gpu.barrier_expect +// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.wait_barrier +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_gpu.local_store +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %[[#WG2:]] = arith.cmpi eq, %[[#TASKID]], %c2_i32 : i32 +// CHECK: scf.if %[[#WG2]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: triton_gpu.local_load +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: tt.experimental_descriptor_store +// CHECK: triton_nvidia_gpu.consumer_release + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg4: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: f32) attributes {noinline = false} { + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c132_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 132 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %cst_0 = arith.constant {async_task_id = dense<2> : vector<1xi32>} dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %0 = arith.addi %arg7, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg5, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.addi %arg6, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.divsi %4, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.muli %3, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.sitofp %arg6 {async_task_id = dense<2> : vector<1xi32>} : i32 to f32 + %9 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %10 = tt.splat %arg11 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + scf.for %arg12 = %7 to %6 step %c132_i32 : i32 { + %11 = arith.muli %arg12, %c128_i32 {async_task_id = dense<[0, 2]> : vector<2xi32>} : i32 + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %12 = scf.for %arg13 = %c0_i32 to %1 step %c1_i32 iter_args(%arg14 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { + %45 = arith.muli %arg13, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %46 = tt.experimental_descriptor_load %arg0[%11, %45] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64xf16, #blocked> + %47 = triton_gpu.local_alloc %46 {async_task_id = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %48 = tt.experimental_descriptor_load %arg1[%45, %c0_i32] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256xf16, #blocked1> + %49 = triton_gpu.local_alloc %48 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %50 = triton_nvidia_gpu.warp_group_dot %47, %49, %arg14 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %50 : tensor<128x256xf32, #mma> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %14 = arith.divf %13, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %15 = tt.expand_dims %14 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %16 = tt.broadcast %15 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %17 = arith.subf %12, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %18 = arith.mulf %17, %17 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %19 = "tt.reduce"(%18) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %20 = arith.divf %19, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %21 = arith.addf %20, %10 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %22 = math.sqrt %21 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %23 = arith.divf %cst_0, %22 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %24 = tt.experimental_descriptor_load %arg3[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<256xf16, #blocked2> + %25 = tt.experimental_descriptor_load %arg4[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<256xf16, #blocked2> + %26 = tt.expand_dims %23 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %27 = tt.broadcast %26 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %28 = arith.mulf %17, %27 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %29 = triton_gpu.convert_layout %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.expand_dims %29 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %31 = triton_gpu.convert_layout %30 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %32 = arith.extf %31 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %33 = triton_gpu.convert_layout %32 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %34 = tt.broadcast %33 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %35 = arith.mulf %28, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %36 = triton_gpu.convert_layout %25 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %37 = tt.expand_dims %36 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %38 = triton_gpu.convert_layout %37 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %39 = arith.extf %38 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %40 = triton_gpu.convert_layout %39 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %41 = tt.broadcast %40 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %42 = arith.addf %35, %41 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %43 = arith.truncf %42 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %44 = triton_gpu.convert_layout %43 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.experimental_descriptor_store %arg2[%11, %c0_i32], %44 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr, tensor<128x256xf16, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} From a182fedfa9083e87ea165f1d966e7a95908df50d Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 11 Dec 2024 13:33:36 -0800 Subject: [PATCH 8/8] Address comments --- .../TritonGPU/Transforms/WSCodePartition.cpp | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp index c688f0986c..8443c1e9f7 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -403,8 +403,8 @@ struct Channel { unsigned numBuffers; }; -// Find transitive users of the root op. Ignore control flow ops such as yield -// in between. +// Find transitive users of the root op. Track through control flow ops (such as +// yield) to get to the real users. void getTransitiveUsers(Value root, SetVector> &users) { for (Operation *userOp : root.getUsers()) { @@ -462,8 +462,8 @@ void collectAsyncChannels(SmallVector> &channels, SetVector> users; getTransitiveUsers(result, users); - for (auto pc : users) { - auto userOp = pc.first; + for (auto user : users) { + auto userOp = user.first; auto consumerTaskIds = getAsyncTaskIds(userOp); if (consumerTaskIds.empty()) continue; @@ -474,7 +474,7 @@ void collectAsyncChannels(SmallVector> &channels, // Add a channel from the single producer task to consumerTaskIds. if (consumerTaskIds.size() > 0) { channels.push_back(std::make_unique( - producerTaskId, consumerTaskIds, userOp, pc.second, + producerTaskId, consumerTaskIds, userOp, user.second, producerNumBuffers)); } } @@ -495,8 +495,13 @@ void collectAsyncChannels(SmallVector> &channels, }); } -// Update map, which will be keyed by getDstOp() of the channel. Use -// orderedChannels to enforce deterministic order for map. +// Group channels in two ways: +// - by producer ops. One producer corresponds to multiple channels. This +// grouping will be used to create buffers per shared producer. +// - by consumer ops. One consumer corresponds to multiple channels. This +// grouping will be used to create barriers per shared consumer. +// Also compute orderedChannels, which will be keyed by getDstOp() of channels, +// to enforce deterministic order for map. void groupChannels( SmallVector &channels, DenseMap> &channelsGroupedByProducers, @@ -1006,6 +1011,7 @@ DenseMap createBuffer( DenseMap bufferMap; MLIRContext *context = funcOp.getContext(); OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); for (auto &item : channelsGroupedByProducers) { auto &channels = item.second; auto srcValue = item.first->getSrcOperand(); @@ -1038,7 +1044,6 @@ DenseMap createBuffer( Type memdescType = tt::MemDescType::get(bufferShape, elemType, sharedLayout, sharedMemorySpace, /*mutableMemory*/ true); - builder.setInsertionPointToStart(&(funcOp.getBody().front())); Value buffer = builder.create(funcOp.getLoc(), memdescType); @@ -1115,6 +1120,8 @@ static Operation *createAsyncCopy(const DenseMap &bufferMap, return copy; } +// Create a local copy for a channel that is populated by the producer and +// accessed by the consumer. static void createLocalCopy(const DenseMap &bufferMap, Channel *channel, Value srcBufferIdx, Value dstBufferIdx) { @@ -1505,7 +1512,6 @@ void insertAsyncCopy(triton::FuncOp funcOp, auto tSize = forOp.getBody()->getArguments().size(); assert(tSize >= 2); bufferIdx = forOp.getBody()->getArguments().back(); - phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument } else { // Producer is not in a ForOp, create phase and bufferIdx here which will // be used by both producer and consumers. @@ -1516,8 +1522,6 @@ void insertAsyncCopy(triton::FuncOp funcOp, builder.setAsynTaskIdsFromArray(asyncTasksPC); bufferIdx = builder.createWithAsyncTaskIds( srcOp->getLoc(), 0, 32); - phase = builder.createWithAsyncTaskIds( - srcOp->getLoc(), 0, 1); } for (auto channel : mutuallyNonDominatingChannels) { @@ -1626,6 +1630,7 @@ class TritonGPUWSCodePartitionPass funcOp.dump(); }); + // Step 7: Lower the loads. Also add local copy ops for non-load producers. insertAsyncCopy(funcOp, channelsGroupedByProducers, bufferMap); LLVM_DEBUG({ LDBG("\n\nwith async copy"); @@ -1634,7 +1639,6 @@ class TritonGPUWSCodePartitionPass // If loadResult has a single use which is LocalAlloc, we can get rid of // sharedLoad and replace all uses of LocalAlloc with viewLoad. - DenseMap opsToReplace; foldLocalLoads(funcOp); LLVM_DEBUG({ LDBG("\n\nsimplify localLoad + localAlloc");