diff --git a/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp b/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp
index 3fcaef1338..e2b2530417 100644
--- a/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp
+++ b/third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp
@@ -64,6 +64,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -184,28 +185,23 @@ static const TTGIRToTLXMapping opMappings[] = {
     {"scf.while", "while", "While loop"},
 
     // Arith operations
+    // Binary arith ops (add, sub, mul, div, rem, xor, and, or) are handled
+    // as infix operators (a + b, a * b, etc.) in printSimplifiedOp.
     {"arith.constant", "const", "Constant value"},
-    {"arith.addi", "add", "Integer addition"},
-    {"arith.subi", "sub", "Integer subtraction"},
-    {"arith.muli", "mul", "Integer multiplication"},
-    {"arith.divsi", "div", "Signed integer division"},
-    {"arith.remsi", "rem", "Signed integer remainder"},
-    {"arith.addf", "addf", "Float addition"},
-    {"arith.subf", "subf", "Float subtraction"},
-    {"arith.mulf", "mulf", "Float multiplication"},
-    {"arith.divf", "divf", "Float division"},
     {"arith.cmpf", "cmpf", "Float comparison"},
     {"arith.cmpi", "cmpi", "Integer comparison"},
     {"arith.select", "select", "Select operation"},
     {"arith.extui", "extui", "Extend unsigned integer"},
     {"arith.extsi", "extsi", "Extend signed integer"},
     {"arith.trunci", "trunci", "Truncate integer"},
-    {"arith.xori", "xor", "Integer XOR"},
-    {"arith.andi", "and", "Integer AND"},
-    {"arith.ori", "or", "Integer OR"},
-    {"arith.maxf", "maxf", "Float max"},
-    {"arith.minf", "minf", "Float min"},
-    {"arith.negf", "negf", "Float negation"},
+    {"arith.maxf", "tl.maximum", "Float max"},
+    {"arith.maxnumf", "tl.maximum", "Float max (NaN-ignoring)"},
+    {"arith.minf", "tl.minimum", "Float min"},
+    {"arith.minnumf", "tl.minimum", "Float min (NaN-ignoring)"},
+    {"arith.maxsi", "tl.max", "Signed integer max"},
+    {"arith.maxui", "tl.max", "Unsigned integer max"},
+    {"arith.minsi", "tl.min", "Signed integer min"},
+    {"arith.minui", "tl.min", "Unsigned integer min"},
     {"arith.sitofp", "sitofp", "Signed int to float"},
     {"arith.fptosi", "fptosi", "Float to signed int"},
     {"arith.uitofp", "uitofp", "Unsigned int to float"},
@@ -215,28 +211,68 @@ static const TTGIRToTLXMapping opMappings[] = {
     {"arith.bitcast", "bitcast", "Bitcast"},
 
     // Triton operations
-    {"tt.splat", "splat", "Splat scalar to tensor"},
-    {"tt.broadcast", "broadcast", "Broadcast tensor"},
-    {"tt.expand_dims", "expand_dims", "Expand dimensions"},
-    {"tt.reduce", "reduce", "Reduce operation"},
-    {"tt.dot", "dot", "Matrix multiply"},
-    {"tt.load", "load", "Load from global memory"},
-    {"tt.store", "store", "Store to global memory"},
+    {"tt.splat", "tl.splat", "Splat scalar to tensor"},
+    {"tt.broadcast", "tl.broadcast", "Broadcast tensor"},
+    {"tt.expand_dims", "tl.expand_dims", "Expand dimensions"},
+    {"tt.reduce", "tl.reduce", "Reduce operation"},
+    {"tt.dot", "tl.dot", "Matrix multiply"},
+    {"tt.load", "tl.load", "Load from global memory"},
+    {"tt.store", "tl.store", "Store to global memory"},
     {"tt.addptr", "addptr", "Add to pointer"},
-    {"tt.make_range", "make_range", "Make range"},
-    {"tt.trans", "trans", "Transpose"},
-    {"tt.reshape", "reshape", "Reshape tensor"},
-    {"tt.cat", "cat", "Concatenate"},
-    {"tt.join", "join", "Join tensors"},
-    {"tt.split", "split", "Split tensor"},
-    {"tt.get_program_id", "get_program_id", "Get program ID"},
-    {"tt.get_num_programs", "get_num_programs", "Get number of programs"},
+    {"tt.make_range", "tl.arange", "Make range"},
+    {"tt.trans", "tl.trans", "Transpose"},
+    {"tt.reshape", "tl.reshape", "Reshape tensor"},
+    {"tt.cat", "tl.cat", "Concatenate"},
+    {"tt.join", "tl.join", "Join tensors"},
+    {"tt.split", "tl.split", "Split tensor"},
+    {"tt.get_program_id", "tl.program_id", "Get program ID"},
+    {"tt.get_num_programs", "tl.num_programs", "Get number of programs"},
     {"tt.return", "return", "Return from function"},
 
+    // Math dialect operations
+    {"math.exp", "tl.math.exp", "Natural exponential"},
+    {"math.exp2", "tl.math.exp2", "Base-2 exponential"},
+    {"math.log", "tl.math.log", "Natural logarithm"},
+    {"math.log2", "tl.math.log2", "Base-2 logarithm"},
+    {"math.sin", "tl.math.sin", "Sine"},
+    {"math.cos", "tl.math.cos", "Cosine"},
+    {"math.sqrt", "tl.math.sqrt", "Square root"},
+    {"math.rsqrt", "tl.math.rsqrt", "Reciprocal square root"},
+    {"math.erf", "tl.math.erf", "Error function"},
+    {"math.floor", "tl.math.floor", "Floor"},
+    {"math.ceil", "tl.math.ceil", "Ceiling"},
+    {"math.fabs", "tl.math.abs", "Float absolute value"},
+    {"math.iabs", "tl.math.abs", "Integer absolute value"},
+    {"math.fma", "tl.math.fma", "Fused multiply-add"},
+    {"tt.precise_sqrt", "tl.math.sqrt_rn", "IEEE-rounded square root"},
+
     // GPU operations
     {"gpu.barrier", "gpu.barrier", "GPU barrier"},
 };
 
+// Infix operator mapping for binary arith ops
+llvm::StringMap<StringRef> buildInfixOpMap() {
+  llvm::StringMap<StringRef> map;
+  map["arith.addi"] = "+";
+  map["arith.addf"] = "+";
+  map["arith.subi"] = "-";
+  map["arith.subf"] = "-";
+  map["arith.muli"] = "*";
+  map["arith.mulf"] = "*";
+  map["arith.divsi"] = "//";
+  map["arith.divui"] = "//";
+  map["arith.divf"] = "/";
+  map["arith.remsi"] = "%";
+  map["arith.remui"] = "%";
+  map["arith.xori"] = "^";
+  map["arith.andi"] = "&";
+  map["arith.ori"] = "|";
+  map["arith.shli"] = "<<";
+  map["arith.shrsi"] = ">>";
+  map["arith.shrui"] = ">>";
+  return map;
+}
+
 // Build a lookup map for fast operation name lookup
 llvm::StringMap buildOpNameMap() {
   llvm::StringMap map;
@@ -251,17 +287,58 @@ llvm::StringMap buildOpNameMap() {
 // values
 std::string
 getValueName(Value v,
-             const DenseMap<Value, Value> *argSubstitutionMap = nullptr) {
+             const DenseMap<Value, Value> *argSubstitutionMap = nullptr,
+             bool inlineConstants = true) {
   // Check if this value should be substituted
   if (argSubstitutionMap) {
     auto it = argSubstitutionMap->find(v);
     if (it != argSubstitutionMap->end()) {
       // Recursively get the name of the substituted value (without
       // substitution)
-      return getValueName(it->second, nullptr);
+      return getValueName(it->second, nullptr, inlineConstants);
+    }
+  }
+
+  // Pass through convert_layout: use its input operand's name instead
+  if (Operation *defOp = v.getDefiningOp()) {
+    if (defOp->getName().getStringRef() == "ttg.convert_layout" &&
+        defOp->getNumOperands() > 0) {
+      return getValueName(defOp->getOperand(0), argSubstitutionMap,
+                          inlineConstants);
+    }
+  }
+
+  // Inline constants: if this value is defined by arith.constant, return the
+  // literal value
+  if (inlineConstants) {
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (defOp->getName().getStringRef() == "arith.constant") {
+        if (auto valueAttr = defOp->getAttr("value")) {
+          std::string result;
+          llvm::raw_string_ostream os(result);
+          if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
+            if (intAttr.getType().isInteger(1)) {
+              os << (intAttr.getValue().getBoolValue() ? "True" : "False");
+            } else {
+              os << intAttr.getValue();
+            }
+          } else if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
+            SmallString<16> str;
+            floatAttr.getValue().toString(str);
+            os << str;
+          } else {
+            // Fall through to normal name handling for unsupported constant
+            // types
+            goto normal_name;
+          }
+          os.flush();
+          return result;
+        }
+      }
     }
   }
 
+normal_name:
   std::string name;
   llvm::raw_string_ostream os(name);
   v.printAsOperand(os, OpPrintingFlags());
@@ -420,29 +497,29 @@ LocalAllocInfo analyzeLocalAlloc(Operation *localAllocOp) {
 }
 
 // Check if an operation should be skipped because it's folded into
-// a barrier alloc
+// a barrier alloc or not meaningful in TLX output
 bool shouldSkipOp(Operation *op,
                   const DenseMap<Operation *, LocalAllocInfo> &allocInfoMap,
                   llvm::DenseSet<Operation *> &skippedOps) {
   StringRef opName = op->getName().getStringRef();
 
-  // Skip init_barrier - it's folded into alloc_barriers
-  if (opName == "ttng.init_barrier") {
-    return true;
-  }
-
-  // Skip warp_return - it's implicit in the with block structure
-  if (opName == "ttg.warp_return") {
-    return true;
-  }
-
-  // Skip warp_yield - it's implicit in the with block structure
-  if (opName == "ttg.warp_yield") {
-    return true;
-  }
-
-  // Skip warp_specialize.partitions - it's not meaningful in TLX format
-  if (opName == "ttg.warp_specialize.partitions") {
+  // Operations to skip in TLX output:
+  // - ttng.init_barrier: folded into alloc_barriers
+  // - ttg.warp_return/warp_yield: implicit in with block structure
+  // - ttg.warp_specialize.partitions: not meaningful in TLX format
+  // - gpu.barrier: not needed in TLX
+  // - arith.constant: values are inlined at use sites
+  // - ttg.convert_layout: internal layout conversion
+  // - tt.return: function terminator
+  // - tt.reduce.return: internal to reduce operation
+  static const llvm::StringSet<> opsToSkip = {
+      "ttng.init_barrier",  "ttg.warp_return",
+      "ttg.warp_yield",     "ttg.warp_specialize.partitions",
+      "gpu.barrier",        "arith.constant",
+      "ttg.convert_layout", "tt.return",
+      "tt.reduce.return",
+  };
+  if (opsToSkip.contains(opName)) {
     return true;
   }
 
@@ -599,6 +676,25 @@ void printIfOp(Operation *op, llvm::raw_ostream &os,
   }
 }
 
+// Helper to check if a region has meaningful operations (not just skipped ops)
+bool regionHasMeaningfulOps(
+    Region &region, const DenseMap<Operation *, LocalAllocInfo> &allocInfoMap,
+    llvm::DenseSet<Operation *> &skippedOps) {
+  for (Block &block : region) {
+    for (Operation &op : block) {
+      // Skip operations that would be filtered out
+      if (shouldSkipOp(&op, allocInfoMap, skippedOps))
+        continue;
+      // Skip scf.yield as it's handled specially
+      if (op.getName().getStringRef() == "scf.yield")
+        continue;
+      // Found a meaningful operation
+      return true;
+    }
+  }
+  return false;
+}
+
 // Print warp_specialize operation in TLX async_tasks format
 void printWarpSpecialize(
     Operation *op, llvm::raw_ostream &os,
@@ -647,6 +743,12 @@ void printWarpSpecialize(
             "ttg.warp_specialize.partitions") {
           // Each region in warp_specialize.partitions is a partition
           for (Region &partitionRegion : innerOp.getRegions()) {
+            // Skip empty partitions (only contain skipped ops)
+            if (!regionHasMeaningfulOps(partitionRegion, allocInfoMap,
+                                        skippedOps)) {
+              continue;
+            }
+
             // Build substitution map for this partition
             DenseMap<Value, Value> argSubstitutionMap;
             if (!partitionRegion.empty()) {
@@ -701,6 +803,46 @@ void printSimplifiedOp(
     return;
   }
 
+  // Special handling for tt.reshape - print target shape
+  if (opName == "tt.reshape" && op->getNumResults() > 0) {
+    if (auto resultType =
+            dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      os << getValueName(op->getResult(0), argSubstitutionMap) << " = ";
+      os << "tl.reshape(";
+      os << getValueName(op->getOperand(0), argSubstitutionMap) << ", [";
+      ArrayRef<int64_t> shape = resultType.getShape();
+      for (size_t i = 0; i < shape.size(); ++i) {
+        if (i > 0)
+          os << ", ";
+        os << shape[i];
+      }
+      os << "])\n";
+      return;
+    }
+  }
+
+  // Special handling for binary infix operators (a + b, a * b, etc.)
+  {
+    static llvm::StringMap<StringRef> infixOpMap = buildInfixOpMap();
+    auto infixIt = infixOpMap.find(opName);
+    if (infixIt != infixOpMap.end() && op->getNumOperands() == 2 &&
+        op->getNumResults() > 0) {
+      os << getValueName(op->getResult(0), argSubstitutionMap) << " = "
+         << getValueName(op->getOperand(0), argSubstitutionMap) << " "
+         << infixIt->second << " "
+         << getValueName(op->getOperand(1), argSubstitutionMap) << "\n";
+      return;
+    }
+  }
+
+  // Special handling for unary negation
+  if (opName == "arith.negf" && op->getNumOperands() == 1 &&
+      op->getNumResults() > 0) {
+    os << getValueName(op->getResult(0), argSubstitutionMap) << " = -"
+       << getValueName(op->getOperand(0), argSubstitutionMap) << "\n";
+    return;
+  }
+
   // Special handling for local_alloc
   if (opName == "ttg.local_alloc") {
     auto it = allocInfoMap.find(op);