From ac0ac04df95ab3bed6225b184a5907dcb2564425 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 27 Jan 2025 09:13:09 -0800 Subject: [PATCH] Fix issues for if nested in a loop --- .../TritonGPU/Transforms/WSDataPartition.cpp | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp index 17d73ea9ac..21d6024e4f 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp @@ -60,12 +60,15 @@ void fixTaskId(triton::FuncOp &funcOp) { // Do not update loads. if (isa(defOp)) continue; + // Skip control flow ops. + if (isa(op)) + continue; auto defTaskIds = getAsyncTaskIds(defOp); // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if // necessary. if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { // Const ops with same value but different task ids can be folded. - if (isa(defOp)) { + if (defOp->getDialect()->getNamespace() == "arith") { LLVM_DEBUG({ LDBG("backward fixing taskId for"); defOp->dump(); @@ -80,7 +83,7 @@ void fixTaskId(triton::FuncOp &funcOp) { if (operand.hasOneUse() && !oneVecCoversTheOther(asyncTaskIds, defTaskIds)) { // YieldOp may lose task attribute during MLIR canonicalization. - if (isa(op)) { + if (isa(op)) { LLVM_DEBUG({ LDBG("forward fixing taskId for"); defOp->dump(); @@ -131,6 +134,23 @@ void getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize, } else if (auto dotOp = dyn_cast(op)) { queue.push_back(dim == 0 ? dotOp.getA() : dotOp.getB()); queue.push_back(dotOp.getC()); + } else if (auto ifOp = dyn_cast(op)) { + // track yield value + // find result index of v + unsigned resultIndex = 0; + for (int i = 0; i < op->getNumResults(); ++i) { + if (op->getResult(i) == v) { + resultIndex = i; + break; + } + } + + auto thenYieldArg = ifOp.thenYield().getOperand(resultIndex); + backwardSlice.insert(ifOp.thenYield()); + queue.push_back(thenYieldArg); + auto elseYieldArg = ifOp.elseYield().getOperand(resultIndex); + backwardSlice.insert(ifOp.elseYield()); + queue.push_back(elseYieldArg); } else { llvm_unreachable("Unexpected op"); }