diff --git a/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp b/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp index ccd5e3805b..9f562f50a7 100644 --- a/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp +++ b/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp @@ -58,6 +58,7 @@ //===----------------------------------------------------------------------===// #include "IR/Dialect.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -149,8 +150,17 @@ static const TTGIRToTLXMapping opMappings[] = { "Wait for TMA stores to complete"}, {"tt.make_tensor_descriptor", "tlx.make_tensor_descriptor", "Create TMA descriptor on device"}, + {"ttng.tensormap_create", "tlx.make_tensor_descriptor", + "Create TMA descriptor on device (Blackwell)"}, {"tt.reinterpret_tensor_descriptor", "tlx.reinterpret_tensor_descriptor", "Reinterpret TMA descriptor with new shape"}, + {"ttng.reinterpret_tensor_descriptor", "tlx.reinterpret_tensor_descriptor", + "Reinterpret TMA descriptor (Blackwell)"}, + {"ttg.global_scratch_alloc", "tlx.allocate_tensor_descriptor", + "Allocate global scratch for TMA descriptors"}, + {"ttng.tensormap_fenceproxy_acquire", + "tl.extra.cuda.experimental_tensormap_fenceproxy_acquire", + "Fence proxy acquire for TMA descriptor"}, // MMA operations {"ttng.warp_group_dot", "tlx.warp_group_dot", @@ -194,12 +204,7 @@ static const TTGIRToTLXMapping opMappings[] = { // Binary arith ops (add, sub, mul, div, rem, xor, and, or) are handled // as infix operators (a + b, a * b, etc.) in printSimplifiedOp. {"arith.constant", "const", "Constant value"}, - {"arith.cmpf", "cmpf", "Float comparison"}, - {"arith.cmpi", "cmpi", "Integer comparison"}, {"arith.select", "select", "Select operation"}, - {"arith.extui", "extui", "Extend unsigned integer"}, - {"arith.extsi", "extsi", "Extend signed integer"}, - {"arith.trunci", "trunci", "Truncate integer"}, {"arith.maxf", "tl.maximum", "Float max"}, {"arith.maxnumf", "tl.maximum", "Float max (NaN-propagating)"}, {"arith.minf", "tl.minimum", "Float min"}, @@ -208,13 +213,6 @@ static const TTGIRToTLXMapping opMappings[] = { {"arith.maxui", "tl.max", "Unsigned integer max"}, {"arith.minsi", "tl.min", "Signed integer min"}, {"arith.minui", "tl.min", "Unsigned integer min"}, - {"arith.sitofp", "sitofp", "Signed int to float"}, - {"arith.fptosi", "fptosi", "Float to signed int"}, - {"arith.uitofp", "uitofp", "Unsigned int to float"}, - {"arith.fptoui", "fptoui", "Float to unsigned int"}, - {"arith.truncf", "truncf", "Truncate float"}, - {"arith.extf", "extf", "Extend float"}, - {"arith.bitcast", "bitcast", "Bitcast"}, // Triton operations {"tt.splat", "tl.splat", "Splat scalar to tensor"}, @@ -279,6 +277,70 @@ llvm::StringMap buildInfixOpMap() { return map; } +// Get comparison operator string for arith.cmpi predicates +StringRef getCmpIOperator(int64_t predicate) { + switch (predicate) { + case 0: + return "=="; // eq + case 1: + return "!="; // ne + case 2: + return "<"; // slt + case 3: + return "<="; // sle + case 4: + return ">"; // sgt + case 5: + return ">="; // sge + case 6: + return "<"; // ult + case 7: + return "<="; // ule + case 8: + return ">"; // ugt + case 9: + return ">="; // uge + default: + return "??"; + } +} + +// Get comparison operator string for arith.cmpf predicates +StringRef getCmpFOperator(int64_t predicate) { + switch (predicate) { + case 0: + return "False"; // false + case 1: + return "=="; // oeq + case 2: + return ">"; // ogt + case 3: + return ">="; // oge + case 4: + return "<"; // olt + case 5: + return "<="; // ole + case 6: + return "!="; // one + case 8: + return "=="; // ueq + case 9: + return ">"; // ugt + case 10: + return ">="; // uge + case 11: + return "<"; // ult + case 12: + return "<="; // ule + case 13: + return "!="; // une + case 15: + return "True"; // true + default: + return "??"; + } +} + // Build a lookup map for fast operation name lookup llvm::StringMap buildOpNameMap() { llvm::StringMap map; @@ -305,9 +367,16 @@ getValueName(Value v, } } - // Pass through convert_layout: use its input operand's name instead + // Pass through convert_layout and type casts: use the input operand's name if (Operation *defOp = v.getDefiningOp()) { - if (defOp->getName().getStringRef() == "ttg.convert_layout" && + static const llvm::StringSet<> transparentOps = { + "ttg.convert_layout", "arith.extui", "arith.extsi", + "arith.extf", "arith.trunci", "arith.truncf", + "arith.sitofp", "arith.uitofp", "arith.fptosi", + "arith.fptoui", "arith.bitcast", "arith.index_cast", + "arith.index_castui", + }; + if (transparentOps.contains(defOp->getName().getStringRef()) && defOp->getNumOperands() > 0) { return getValueName(defOp->getOperand(0), argSubstitutionMap, inlineConstants); @@ -344,6 +413,36 @@ getValueName(Value v, } } + // Inline constants: if this value is defined by arith.constant, return the + // literal value + if (inlineConstants) { + if (Operation *defOp = v.getDefiningOp()) { + if (defOp->getName().getStringRef() == "arith.constant") { + if (auto valueAttr = defOp->getAttr("value")) { + std::string result; + llvm::raw_string_ostream os(result); + if (auto intAttr = dyn_cast(valueAttr)) { + if (intAttr.getType().isInteger(1)) { + os << (intAttr.getValue().getBoolValue() ? "True" : "False"); + } else { + os << intAttr.getValue(); + } + } else if (auto floatAttr = dyn_cast(valueAttr)) { + SmallString<16> str; + floatAttr.getValue().toString(str); + os << str; + } else { + // Fall through to normal name handling for unsupported constant + // types + goto normal_name; + } + os.flush(); + return result; + } + } + } + } + normal_name: std::string name; llvm::raw_string_ostream os(name); @@ -516,6 +615,7 @@ bool shouldSkipOp(Operation *op, // - gpu.barrier: not needed in TLX // - arith.constant: values are inlined at use sites // - ttg.convert_layout: internal layout conversion + // - arith cast ops: type coercions transparent in Python // - tt.return: function terminator // - tt.reduce.return: internal to reduce operation static const llvm::StringSet<> opsToSkip = { @@ -523,7 +623,13 @@ bool shouldSkipOp(Operation *op, "ttg.warp_yield", "ttg.warp_specialize.partitions", "gpu.barrier", "arith.constant", "ttg.convert_layout", "tt.return", - "tt.reduce.return", + "tt.reduce.return", "arith.extui", + "arith.extsi", "arith.extf", + "arith.trunci", "arith.truncf", + "arith.sitofp", "arith.uitofp", + "arith.fptosi", "arith.fptoui", + "arith.bitcast", "arith.index_cast", + "arith.index_castui", }; if (opsToSkip.contains(opName)) { return true; @@ -562,6 +668,23 @@ bool shouldSkipOp(Operation *op, return skippedOps.count(op) > 0; } +static const llvm::StringSet<> castOpsSet = { + "arith.extui", "arith.extsi", "arith.trunci", "arith.extf", + "arith.truncf", "arith.bitcast", "arith.sitofp", "arith.uitofp", + "arith.fptosi", "arith.fptoui", "arith.index_cast", "arith.index_castui", +}; + +static Value resolveThroughCasts(Value v) { + while (auto *op = v.getDefiningOp()) { + if (castOpsSet.contains(op->getName().getStringRef()) && + op->getNumOperands() > 0) + v = op->getOperand(0); + else + break; + } + return v; +} + // Forward declarations void printRegion(Region ®ion, llvm::raw_ostream &os, const llvm::StringMap &opNameMap, @@ -570,6 +693,28 @@ void printRegion(Region ®ion, llvm::raw_ostream &os, DenseMap *argSubstitutionMap = nullptr, ArrayRef yieldTargets = {}); +struct ForLoopInfo { + unsigned iterArgIdx; // header block arg index of the iterator + std::string start; // init value expression + std::string end; // bound expression + std::string step; // step expression + Operation *stepOp; // addi op to add to skippedOps +}; + +void printCFRegion(Region ®ion, llvm::raw_ostream &os, + const llvm::StringMap &opNameMap, + const DenseMap &allocInfoMap, + llvm::DenseSet &skippedOps, unsigned indent, + DenseMap *argSubstitutionMap = nullptr); + +void printCFBlocks(Block *startBlock, Block *stopBlock, llvm::raw_ostream &os, + const llvm::StringMap &opNameMap, + const DenseMap &allocInfoMap, + llvm::DenseSet &skippedOps, unsigned indent, + DenseMap *argSubstitutionMap, + llvm::SmallDenseSet &visitedBlocks, + const DenseMap &forLoopHeaders); + // Print scf.for in Python range syntax void printForOp(Operation *op, llvm::raw_ostream &os, const llvm::StringMap &opNameMap, @@ -849,6 +994,20 @@ void printSimplifiedOp( return; } + // Special handling for cmpi/cmpf - print as infix comparison + if ((opName == "arith.cmpi" || opName == "arith.cmpf") && + op->getNumOperands() == 2 && op->getNumResults() > 0) { + if (auto predAttr = op->getAttrOfType("predicate")) { + int64_t pred = predAttr.getInt(); + StringRef cmpOp = (opName == "arith.cmpi") ? getCmpIOperator(pred) + : getCmpFOperator(pred); + os << getValueName(op->getResult(0), argSubstitutionMap) << " = " + << getValueName(op->getOperand(0), argSubstitutionMap) << " " << cmpOp + << " " << getValueName(op->getOperand(1), argSubstitutionMap) << "\n"; + return; + } + } + // Special handling for local_alloc if (opName == "ttg.local_alloc") { auto it = allocInfoMap.find(op); @@ -1028,12 +1187,531 @@ void printBlock(Block &block, llvm::raw_ostream &os, } } +// If the condition value is defined by a cmpi/cmpf in the same block as the +// cf.cond_br, return the inlined comparison expression (e.g., "var_0 < var_1") +// and add the defining op to skippedOps so it won't be printed separately. +// Returns empty string if inlining is not possible. +std::string getInlinedCondExpr(Value cond, + llvm::DenseSet &skippedOps, + DenseMap *argSubstitutionMap) { + // Resolve through transparent cast ops to find the actual comparison + Value resolved = resolveThroughCasts(cond); + + auto *defOp = resolved.getDefiningOp(); + if (!defOp || defOp->getNumOperands() != 2 || defOp->getNumResults() == 0) + return ""; + auto opName = defOp->getName().getStringRef(); + if (opName != "arith.cmpi" && opName != "arith.cmpf") + return ""; + auto predAttr = defOp->getAttrOfType("predicate"); + if (!predAttr) + return ""; + // Only inline if all uses of the comparison result are in CF terminators + // (cond_br condition or branch operands), which the structured printer + // handles directly. + for (auto *user : defOp->getResult(0).getUsers()) { + if (!user->hasTrait()) + return ""; + } + int64_t pred = predAttr.getInt(); + StringRef cmpOp = + (opName == "arith.cmpi") ? getCmpIOperator(pred) : getCmpFOperator(pred); + skippedOps.insert(defOp); + return getValueName(defOp->getOperand(0), argSubstitutionMap) + " " + + cmpOp.str() + " " + + getValueName(defOp->getOperand(1), argSubstitutionMap); +} + +// Print non-terminator ops from a block (used by CF-aware printer) +void printBlockOps(Block &block, llvm::raw_ostream &os, + const llvm::StringMap &opNameMap, + const DenseMap &allocInfoMap, + llvm::DenseSet &skippedOps, unsigned indent, + DenseMap *argSubstitutionMap) { + for (Operation &op : block) { + if (op.hasTrait()) + break; + if (shouldSkipOp(&op, allocInfoMap, skippedOps)) + continue; + + // Reuse the same special-case handling from printBlock + if (op.getName().getStringRef() == "scf.yield") + continue; + + if (op.getName().getStringRef() == "ttg.warp_specialize") { + printWarpSpecialize(&op, os, opNameMap, allocInfoMap, skippedOps, indent); + continue; + } + if (op.getName().getStringRef() == "scf.for") { + printForOp(&op, os, opNameMap, allocInfoMap, skippedOps, indent, + argSubstitutionMap); + continue; + } + if (op.getName().getStringRef() == "scf.if") { + printIfOp(&op, os, opNameMap, allocInfoMap, skippedOps, indent, + argSubstitutionMap); + continue; + } + if (op.getNumRegions() > 0) { + printSimplifiedOp(&op, os, opNameMap, allocInfoMap, indent, + argSubstitutionMap); + for (unsigned i = 0; i < indent; ++i) + os << " "; + os << "{\n"; + for (Region ®ion : op.getRegions()) { + printRegion(region, os, opNameMap, allocInfoMap, skippedOps, indent + 1, + argSubstitutionMap); + } + for (unsigned i = 0; i < indent; ++i) + os << " "; + os << "}\n"; + } else { + printSimplifiedOp(&op, os, opNameMap, allocInfoMap, indent, + argSubstitutionMap); + } + } +} + +// Print block arg assignments: dest_arg = src_value +// If skipArgIdx >= 0, skip that arg index (used for for-loop iterators). +void printBlockArgAssignments(Block *dest, OperandRange operands, + llvm::raw_ostream &os, unsigned indent, + DenseMap *argSubstitutionMap, + int skipArgIdx = -1) { + for (unsigned i = 0; i < dest->getNumArguments() && i < operands.size(); + ++i) { + if ((int)i == skipArgIdx) + continue; + std::string destName = + getValueName(dest->getArgument(i), argSubstitutionMap); + std::string srcName = getValueName(operands[i], argSubstitutionMap); + if (destName != srcName) { + for (unsigned j = 0; j < indent; ++j) + os << " "; + os << destName << " = " << srcName << "\n"; + } + } +} + +// Detect if a header block represents a for-loop: iter starts at init, +// condition is iter < end, update is iter = iter + step. +bool detectForLoopPattern(Block *header, ForLoopInfo &info, + DenseMap *argSubstitutionMap) { + if (header->getNumArguments() == 0) + return false; + auto condBr = dyn_cast(header->getTerminator()); + if (!condBr) + return false; + + // Resolve condition through casts to find cmpi + Value condResolved = resolveThroughCasts(condBr.getCondition()); + auto *cmpiOp = condResolved.getDefiningOp(); + if (!cmpiOp || cmpiOp->getName().getStringRef() != "arith.cmpi") + return false; + auto predAttr = cmpiOp->getAttrOfType("predicate"); + if (!predAttr) + return false; + int64_t pred = predAttr.getInt(); + // slt (2) or ult (6) + if (pred != 2 && pred != 6) + return false; + + // LHS must be a header block arg (the iterator) + Value lhs = resolveThroughCasts(cmpiOp->getOperand(0)); + auto iterArg = dyn_cast(lhs); + if (!iterArg || iterArg.getOwner() != header) + return false; + unsigned iterIdx = iterArg.getArgNumber(); + + // Find loop body blocks via BFS from trueDest (not crossing header) + Block *trueDest = condBr.getTrueDest(); + llvm::SmallDenseSet bodyBlocks; + llvm::SmallVector worklist; + worklist.push_back(trueDest); + while (!worklist.empty()) { + Block *b = worklist.pop_back_val(); + if (!b || b == header || bodyBlocks.count(b)) + continue; + bodyBlocks.insert(b); + auto *t = b->getTerminator(); + if (auto br = dyn_cast(t)) + worklist.push_back(br.getDest()); + else if (auto cb = dyn_cast(t)) { + worklist.push_back(cb.getTrueDest()); + worklist.push_back(cb.getFalseDest()); + } + } + + // Find step from back-edge predecessor + Operation *stepOp = nullptr; + std::string stepStr; + for (Block *pred : header->getPredecessors()) { + if (!bodyBlocks.count(pred)) + continue; + auto *predTerm = pred->getTerminator(); + Value updateVal; + if (auto br = dyn_cast(predTerm)) { + if (br.getDest() == header && iterIdx < br.getDestOperands().size()) + updateVal = br.getDestOperands()[iterIdx]; + } else if (auto cb = dyn_cast(predTerm)) { + if (cb.getTrueDest() == header && + iterIdx < cb.getTrueDestOperands().size()) + updateVal = cb.getTrueDestOperands()[iterIdx]; + else if (cb.getFalseDest() == header && + iterIdx < cb.getFalseDestOperands().size()) + updateVal = cb.getFalseDestOperands()[iterIdx]; + } + if (!updateVal) + continue; + Value resolved = resolveThroughCasts(updateVal); + auto *addOp = resolved.getDefiningOp(); + if (!addOp || addOp->getName().getStringRef() != "arith.addi" || + addOp->getNumOperands() != 2) + return false; + Value a0 = resolveThroughCasts(addOp->getOperand(0)); + Value a1 = resolveThroughCasts(addOp->getOperand(1)); + if (a0 == iterArg) { + stepStr = getValueName(a1, argSubstitutionMap); + stepOp = addOp; + } else if (a1 == iterArg) { + stepStr = getValueName(a0, argSubstitutionMap); + stepOp = addOp; + } else { + return false; + } + break; + } + if (stepStr.empty()) + return false; + + // Find init from non-body predecessor + std::string initStr; + for (Block *pred : header->getPredecessors()) { + if (bodyBlocks.count(pred)) + continue; + auto *predTerm = pred->getTerminator(); + Value initVal; + if (auto br = dyn_cast(predTerm)) { + if (br.getDest() == header && iterIdx < br.getDestOperands().size()) + initVal = br.getDestOperands()[iterIdx]; + } else if (auto cb = dyn_cast(predTerm)) { + if (cb.getTrueDest() == header && + iterIdx < cb.getTrueDestOperands().size()) + initVal = cb.getTrueDestOperands()[iterIdx]; + else if (cb.getFalseDest() == header && + iterIdx < cb.getFalseDestOperands().size()) + initVal = cb.getFalseDestOperands()[iterIdx]; + } + if (!initVal) + continue; + initStr = getValueName(initVal, argSubstitutionMap); + break; + } + if (initStr.empty()) + return false; + + info.iterArgIdx = iterIdx; + info.start = initStr; + info.end = getValueName(cmpiOp->getOperand(1), argSubstitutionMap); + info.step = stepStr; + info.stepOp = stepOp; + return true; +} + +// Find the immediate post-dominator (merge block) for a cf.cond_br. +// For a simple if-else diamond, this is the single successor shared by +// both branches. We walk forward from each branch to find the first block +// that is reachable from both sides. +Block *findMergeBlock(cf::CondBranchOp condBr) { + Block *trueDest = condBr.getTrueDest(); + Block *falseDest = condBr.getFalseDest(); + + // Simple case: both branches go to the same block + if (trueDest == falseDest) + return trueDest; + + // Collect all blocks reachable from trueDest (following unconditional + // branches only, stopping at conditional branches or blocks with multiple + // predecessors from outside the chain) + llvm::SmallDenseSet trueReachable; + Block *b = trueDest; + while (b) { + trueReachable.insert(b); + auto *term = b->getTerminator(); + if (auto br = dyn_cast(term)) { + b = br.getDest(); + } else { + break; + } + } + + // Walk from falseDest, find first block also reachable from true side + b = falseDest; + while (b) { + if (trueReachable.count(b)) + return b; + auto *term = b->getTerminator(); + if (auto br = dyn_cast(term)) { + b = br.getDest(); + } else { + break; + } + } + + // No merge found — check if trueDest's successor chain leads to falseDest + // or vice versa (one-armed if) + b = trueDest; + while (b) { + auto *term = b->getTerminator(); + if (auto br = dyn_cast(term)) { + if (br.getDest() == falseDest) + return falseDest; + b = br.getDest(); + } else { + break; + } + } + b = falseDest; + while (b) { + auto *term = b->getTerminator(); + if (auto br = dyn_cast(term)) { + if (br.getDest() == trueDest) + return trueDest; + b = br.getDest(); + } else { + break; + } + } + + return nullptr; +} + +// Print a CF region by walking the CFG and emitting structured if/else/while. +// Handles blocks from `startBlock` up to (but not including) `stopBlock`. +void printCFBlocks(Block *startBlock, Block *stopBlock, llvm::raw_ostream &os, + const llvm::StringMap &opNameMap, + const DenseMap &allocInfoMap, + llvm::DenseSet &skippedOps, unsigned indent, + DenseMap *argSubstitutionMap, + llvm::SmallDenseSet &visitedBlocks, + const DenseMap &forLoopHeaders) { + Block *current = startBlock; + while (current && current != stopBlock) { + if (visitedBlocks.count(current)) + return; + visitedBlocks.insert(current); + + // Pre-scan: if the block terminates with cf.cond_br whose condition comes + // from a cmpi/cmpf, mark the comparison as skipped before printing block + // ops so it gets inlined into the if/while line instead of printed twice. + std::string preComputedCondExpr; + if (auto condBrPre = dyn_cast(current->getTerminator())) { + preComputedCondExpr = getInlinedCondExpr(condBrPre.getCondition(), + skippedOps, argSubstitutionMap); + } + + // Print non-terminator operations + printBlockOps(*current, os, opNameMap, allocInfoMap, skippedOps, indent, + argSubstitutionMap); + + Operation *term = current->getTerminator(); + + // cf.cond_br: emit if/else structure + if (auto condBr = dyn_cast(term)) { + Block *trueDest = condBr.getTrueDest(); + Block *falseDest = condBr.getFalseDest(); + Block *mergeBlock = findMergeBlock(condBr); + + // Check if this is a while loop header: the false branch exits the + // loop (goes to mergeBlock or stopBlock) and the true branch is the + // loop body that eventually branches back to current. + // Pattern: current block has args, true branch leads back to current. + bool isWhileLoop = false; + if (current->getNumArguments() > 0) { + // BFS to check if the true-side eventually branches back to current + llvm::SmallVector worklist; + llvm::SmallDenseSet visited; + worklist.push_back(trueDest); + while (!worklist.empty() && !isWhileLoop) { + Block *walk = worklist.pop_back_val(); + if (!walk || visited.count(walk)) + continue; + visited.insert(walk); + if (walk == current) { + isWhileLoop = true; + break; + } + auto *t = walk->getTerminator(); + if (auto br = dyn_cast(t)) { + worklist.push_back(br.getDest()); + } else if (auto cb = dyn_cast(t)) { + worklist.push_back(cb.getTrueDest()); + worklist.push_back(cb.getFalseDest()); + } + } + } + + if (isWhileLoop) { + // Check if this matches a for-loop pattern + auto forIt = forLoopHeaders.find(current); + if (forIt != forLoopHeaders.end()) { + const ForLoopInfo &fli = forIt->second; + // Add step op to skippedOps so it's not printed separately + if (fli.stepOp) + skippedOps.insert(fli.stepOp); + int skipIdx = (int)fli.iterArgIdx; + + for (unsigned i = 0; i < indent; ++i) + os << " "; + std::string iterName = getValueName( + current->getArgument(fli.iterArgIdx), argSubstitutionMap); + if (fli.step == "1") + os << "for " << iterName << " in tl.range(" << fli.start << ", " + << fli.end << "):\n"; + else + os << "for " << iterName << " in tl.range(" << fli.start << ", " + << fli.end << ", " << fli.step << "):\n"; + + // Print true-dest arg assignments (skip iterator) + printBlockArgAssignments(trueDest, condBr.getTrueDestOperands(), os, + indent + 1, argSubstitutionMap, skipIdx); + + // Print loop body + printCFBlocks(trueDest, current, os, opNameMap, allocInfoMap, + skippedOps, indent + 1, argSubstitutionMap, + visitedBlocks, forLoopHeaders); + + // Continue with exit + printBlockArgAssignments(falseDest, condBr.getFalseDestOperands(), os, + indent, argSubstitutionMap, skipIdx); + current = falseDest; + continue; + } + + // Regular while loop + std::string condExpr = + preComputedCondExpr.empty() + ? getValueName(condBr.getCondition(), argSubstitutionMap) + : preComputedCondExpr; + for (unsigned i = 0; i < indent; ++i) + os << " "; + os << "while " << condExpr << ":\n"; + + // Print true-dest arg assignments if any + printBlockArgAssignments(trueDest, condBr.getTrueDestOperands(), os, + indent + 1, argSubstitutionMap); + + // Print loop body (true branch), stopping when we get back to current + printCFBlocks(trueDest, current, os, opNameMap, allocInfoMap, + skippedOps, indent + 1, argSubstitutionMap, visitedBlocks, + forLoopHeaders); + + // After the while, continue with the false dest (exit) + printBlockArgAssignments(falseDest, condBr.getFalseDestOperands(), os, + indent, argSubstitutionMap); + current = falseDest; + continue; + } + + // Regular if/else + std::string condExpr = + preComputedCondExpr.empty() + ? getValueName(condBr.getCondition(), argSubstitutionMap) + : preComputedCondExpr; + for (unsigned i = 0; i < indent; ++i) + os << " "; + os << "if " << condExpr << ":\n"; + + // Print true-dest arg assignments + printBlockArgAssignments(trueDest, condBr.getTrueDestOperands(), os, + indent + 1, argSubstitutionMap); + + if (trueDest != mergeBlock) { + printCFBlocks(trueDest, mergeBlock, os, opNameMap, allocInfoMap, + skippedOps, indent + 1, argSubstitutionMap, visitedBlocks, + forLoopHeaders); + } + + // Print else branch if it's not the merge block or has operands + if (falseDest != mergeBlock || condBr.getFalseDestOperands().size() > 0) { + for (unsigned i = 0; i < indent; ++i) + os << " "; + os << "else:\n"; + printBlockArgAssignments(falseDest, condBr.getFalseDestOperands(), os, + indent + 1, argSubstitutionMap); + if (falseDest != mergeBlock) { + printCFBlocks(falseDest, mergeBlock, os, opNameMap, allocInfoMap, + skippedOps, indent + 1, argSubstitutionMap, + visitedBlocks, forLoopHeaders); + } + } + + // Continue with merge block + if (mergeBlock) { + current = mergeBlock; + continue; + } + return; + } + + // cf.br: unconditional branch — print arg assignments and continue + if (auto br = dyn_cast(term)) { + Block *dest = br.getDest(); + // Skip iterator arg assignment when branching to a for-loop header + int skipIdx = -1; + auto forIt = forLoopHeaders.find(dest); + if (forIt != forLoopHeaders.end()) + skipIdx = (int)forIt->second.iterArgIdx; + printBlockArgAssignments(dest, br.getDestOperands(), os, indent, + argSubstitutionMap, skipIdx); + // If dest is already visited (back-edge) or is the stop block, stop + if (visitedBlocks.count(dest) || dest == stopBlock) + return; + current = dest; + continue; + } + + // Unknown terminator — just stop + return; + } +} + +// Entry point for CF-aware region printing +void printCFRegion(Region ®ion, llvm::raw_ostream &os, + const llvm::StringMap &opNameMap, + const DenseMap &allocInfoMap, + llvm::DenseSet &skippedOps, unsigned indent, + DenseMap *argSubstitutionMap) { + if (region.empty()) + return; + + // Pre-scan: detect for-loop headers + DenseMap forLoopHeaders; + for (Block &block : region) { + ForLoopInfo info; + if (detectForLoopPattern(&block, info, argSubstitutionMap)) + forLoopHeaders[&block] = info; + } + + Block &entry = region.front(); + llvm::SmallDenseSet visitedBlocks; + printCFBlocks(&entry, nullptr, os, opNameMap, allocInfoMap, skippedOps, + indent, argSubstitutionMap, visitedBlocks, forLoopHeaders); +} + void printRegion(Region ®ion, llvm::raw_ostream &os, const llvm::StringMap &opNameMap, const DenseMap &allocInfoMap, llvm::DenseSet &skippedOps, unsigned indent, DenseMap *argSubstitutionMap, ArrayRef yieldTargets) { + // For multi-block regions with CF control flow, use the CF-aware printer + if (std::distance(region.begin(), region.end()) > 1) { + printCFRegion(region, os, opNameMap, allocInfoMap, skippedOps, indent, + argSubstitutionMap); + return; + } + // Single-block region: print sequentially for (Block &block : region) { printBlock(block, os, opNameMap, allocInfoMap, skippedOps, indent, argSubstitutionMap, yieldTargets);