Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 192 additions & 50 deletions third_party/tlx/dialect/lib/Transforms/PrintTTGIRToTLX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

Expand Down Expand Up @@ -184,28 +185,23 @@ static const TTGIRToTLXMapping opMappings[] = {
{"scf.while", "while", "While loop"},

// Arith operations
// Binary arith ops (add, sub, mul, div, rem, xor, and, or) are handled
// as infix operators (a + b, a * b, etc.) in printSimplifiedOp.
{"arith.constant", "const", "Constant value"},
{"arith.addi", "add", "Integer addition"},
{"arith.subi", "sub", "Integer subtraction"},
{"arith.muli", "mul", "Integer multiplication"},
{"arith.divsi", "div", "Signed integer division"},
{"arith.remsi", "rem", "Signed integer remainder"},
{"arith.addf", "addf", "Float addition"},
{"arith.subf", "subf", "Float subtraction"},
{"arith.mulf", "mulf", "Float multiplication"},
{"arith.divf", "divf", "Float division"},
{"arith.cmpf", "cmpf", "Float comparison"},
{"arith.cmpi", "cmpi", "Integer comparison"},
{"arith.select", "select", "Select operation"},
{"arith.extui", "extui", "Extend unsigned integer"},
{"arith.extsi", "extsi", "Extend signed integer"},
{"arith.trunci", "trunci", "Truncate integer"},
{"arith.xori", "xor", "Integer XOR"},
{"arith.andi", "and", "Integer AND"},
{"arith.ori", "or", "Integer OR"},
{"arith.maxf", "maxf", "Float max"},
{"arith.minf", "minf", "Float min"},
{"arith.negf", "negf", "Float negation"},
{"arith.maxf", "tl.maximum", "Float max"},
{"arith.maxnumf", "tl.maximum", "Float max (NaN-propagating)"},
{"arith.minf", "tl.minimum", "Float min"},
{"arith.minnumf", "tl.minimum", "Float min (NaN-propagating)"},
{"arith.maxsi", "tl.max", "Signed integer max"},
{"arith.maxui", "tl.max", "Unsigned integer max"},
{"arith.minsi", "tl.min", "Signed integer min"},
{"arith.minui", "tl.min", "Unsigned integer min"},
{"arith.sitofp", "sitofp", "Signed int to float"},
{"arith.fptosi", "fptosi", "Float to signed int"},
{"arith.uitofp", "uitofp", "Unsigned int to float"},
Expand All @@ -215,28 +211,68 @@ static const TTGIRToTLXMapping opMappings[] = {
{"arith.bitcast", "bitcast", "Bitcast"},

// Triton operations
{"tt.splat", "splat", "Splat scalar to tensor"},
{"tt.broadcast", "broadcast", "Broadcast tensor"},
{"tt.expand_dims", "expand_dims", "Expand dimensions"},
{"tt.reduce", "reduce", "Reduce operation"},
{"tt.dot", "dot", "Matrix multiply"},
{"tt.load", "load", "Load from global memory"},
{"tt.store", "store", "Store to global memory"},
{"tt.splat", "tl.splat", "Splat scalar to tensor"},
{"tt.broadcast", "tl.broadcast", "Broadcast tensor"},
{"tt.expand_dims", "tl.expand_dims", "Expand dimensions"},
{"tt.reduce", "tl.reduce", "Reduce operation"},
{"tt.dot", "tl.dot", "Matrix multiply"},
{"tt.load", "tl.load", "Load from global memory"},
{"tt.store", "tl.store", "Store to global memory"},
{"tt.addptr", "addptr", "Add to pointer"},
{"tt.make_range", "make_range", "Make range"},
{"tt.trans", "trans", "Transpose"},
{"tt.reshape", "reshape", "Reshape tensor"},
{"tt.cat", "cat", "Concatenate"},
{"tt.join", "join", "Join tensors"},
{"tt.split", "split", "Split tensor"},
{"tt.get_program_id", "get_program_id", "Get program ID"},
{"tt.get_num_programs", "get_num_programs", "Get number of programs"},
{"tt.make_range", "tl.arange", "Make range"},
{"tt.trans", "tl.trans", "Transpose"},
{"tt.reshape", "tl.reshape", "Reshape tensor"},
{"tt.cat", "tl.cat", "Concatenate"},
{"tt.join", "tl.join", "Join tensors"},
{"tt.split", "tl.split", "Split tensor"},
{"tt.get_program_id", "tl.program_id", "Get program ID"},
{"tt.get_num_programs", "tl.num_programs", "Get number of programs"},
{"tt.return", "return", "Return from function"},

// Math dialect operations
{"math.exp", "tl.math.exp", "Natural exponential"},
{"math.exp2", "tl.math.exp2", "Base-2 exponential"},
{"math.log", "tl.math.log", "Natural logarithm"},
{"math.log2", "tl.math.log2", "Base-2 logarithm"},
{"math.sin", "tl.math.sin", "Sine"},
{"math.cos", "tl.math.cos", "Cosine"},
{"math.sqrt", "tl.math.sqrt", "Square root"},
{"math.rsqrt", "tl.math.rsqrt", "Reciprocal square root"},
{"math.erf", "tl.math.erf", "Error function"},
{"math.floor", "tl.math.floor", "Floor"},
{"math.ceil", "tl.math.ceil", "Ceiling"},
{"math.fabs", "tl.math.abs", "Float absolute value"},
{"math.iabs", "tl.math.abs", "Integer absolute value"},
{"math.fma", "tl.math.fma", "Fused multiply-add"},
{"tt.precise_sqrt", "tl.math.sqrt_rn", "IEEE-rounded square root"},

// GPU operations
{"gpu.barrier", "gpu.barrier", "GPU barrier"},
};

// Infix operator mapping for binary arith ops
llvm::StringMap<StringRef> buildInfixOpMap() {
llvm::StringMap<StringRef> map;
map["arith.addi"] = "+";
map["arith.addf"] = "+";
map["arith.subi"] = "-";
map["arith.subf"] = "-";
map["arith.muli"] = "*";
map["arith.mulf"] = "*";
map["arith.divsi"] = "//";
map["arith.divui"] = "//";
map["arith.divf"] = "/";
map["arith.remsi"] = "%";
map["arith.remui"] = "%";
map["arith.xori"] = "^";
map["arith.andi"] = "&";
map["arith.ori"] = "|";
map["arith.shli"] = "<<";
map["arith.shrsi"] = ">>";
map["arith.shrui"] = ">>";
return map;
}

// Build a lookup map for fast operation name lookup
llvm::StringMap<StringRef> buildOpNameMap() {
llvm::StringMap<StringRef> map;
Expand All @@ -251,17 +287,58 @@ llvm::StringMap<StringRef> buildOpNameMap() {
// values
std::string
getValueName(Value v,
const DenseMap<Value, Value> *argSubstitutionMap = nullptr) {
const DenseMap<Value, Value> *argSubstitutionMap = nullptr,
bool inlineConstants = true) {
// Check if this value should be substituted
if (argSubstitutionMap) {
auto it = argSubstitutionMap->find(v);
if (it != argSubstitutionMap->end()) {
// Recursively get the name of the substituted value (without
// substitution)
return getValueName(it->second, nullptr);
return getValueName(it->second, nullptr, inlineConstants);
}
}

// Pass through convert_layout: use its input operand's name instead
if (Operation *defOp = v.getDefiningOp()) {
if (defOp->getName().getStringRef() == "ttg.convert_layout" &&
defOp->getNumOperands() > 0) {
return getValueName(defOp->getOperand(0), argSubstitutionMap,
inlineConstants);
}
}

// Inline constants: if this value is defined by arith.constant, return the
// literal value
if (inlineConstants) {
if (Operation *defOp = v.getDefiningOp()) {
if (defOp->getName().getStringRef() == "arith.constant") {
if (auto valueAttr = defOp->getAttr("value")) {
std::string result;
llvm::raw_string_ostream os(result);
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
if (intAttr.getType().isInteger(1)) {
os << (intAttr.getValue().getBoolValue() ? "True" : "False");
} else {
os << intAttr.getValue();
}
} else if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
SmallString<16> str;
floatAttr.getValue().toString(str);
os << str;
} else {
// Fall through to normal name handling for unsupported constant
// types
goto normal_name;
}
os.flush();
return result;
}
}
}
}

normal_name:
std::string name;
llvm::raw_string_ostream os(name);
v.printAsOperand(os, OpPrintingFlags());
Expand Down Expand Up @@ -420,29 +497,29 @@ LocalAllocInfo analyzeLocalAlloc(Operation *localAllocOp) {
}

// Check if an operation should be skipped because it's folded into
// a barrier alloc
// a barrier alloc or not meaningful in TLX output
bool shouldSkipOp(Operation *op,
const DenseMap<Operation *, LocalAllocInfo> &allocInfoMap,
llvm::DenseSet<Operation *> &skippedOps) {
StringRef opName = op->getName().getStringRef();

// Skip init_barrier - it's folded into alloc_barriers
if (opName == "ttng.init_barrier") {
return true;
}

// Skip warp_return - it's implicit in the with block structure
if (opName == "ttg.warp_return") {
return true;
}

// Skip warp_yield - it's implicit in the with block structure
if (opName == "ttg.warp_yield") {
return true;
}

// Skip warp_specialize.partitions - it's not meaningful in TLX format
if (opName == "ttg.warp_specialize.partitions") {
// Operations to skip in TLX output:
// - ttng.init_barrier: folded into alloc_barriers
// - ttg.warp_return/warp_yield: implicit in with block structure
// - ttg.warp_specialize.partitions: not meaningful in TLX format
// - gpu.barrier: not needed in TLX
// - arith.constant: values are inlined at use sites
// - ttg.convert_layout: internal layout conversion
// - tt.return: function terminator
// - tt.reduce.return: internal to reduce operation
static const llvm::StringSet<> opsToSkip = {
"ttng.init_barrier", "ttg.warp_return",
"ttg.warp_yield", "ttg.warp_specialize.partitions",
"gpu.barrier", "arith.constant",
"ttg.convert_layout", "tt.return",
"tt.reduce.return",
};
if (opsToSkip.contains(opName)) {
return true;
}

Expand Down Expand Up @@ -599,6 +676,25 @@ void printIfOp(Operation *op, llvm::raw_ostream &os,
}
}

// Helper to check if a region has meaningful operations (not just skipped ops)
bool regionHasMeaningfulOps(
Region &region, const DenseMap<Operation *, LocalAllocInfo> &allocInfoMap,
llvm::DenseSet<Operation *> &skippedOps) {
for (Block &block : region) {
for (Operation &op : block) {
// Skip operations that would be filtered out
if (shouldSkipOp(&op, allocInfoMap, skippedOps))
continue;
// Skip scf.yield as it's handled specially
if (op.getName().getStringRef() == "scf.yield")
continue;
// Found a meaningful operation
return true;
}
}
return false;
}

// Print warp_specialize operation in TLX async_tasks format
void printWarpSpecialize(
Operation *op, llvm::raw_ostream &os,
Expand Down Expand Up @@ -647,6 +743,12 @@ void printWarpSpecialize(
"ttg.warp_specialize.partitions") {
// Each region in warp_specialize.partitions is a partition
for (Region &partitionRegion : innerOp.getRegions()) {
// Skip empty partitions (only contain skipped ops)
if (!regionHasMeaningfulOps(partitionRegion, allocInfoMap,
skippedOps)) {
continue;
}

// Build substitution map for this partition
DenseMap<Value, Value> argSubstitutionMap;
if (!partitionRegion.empty()) {
Expand Down Expand Up @@ -701,6 +803,46 @@ void printSimplifiedOp(
return;
}

// Special handling for tt.reshape - print target shape
if (opName == "tt.reshape" && op->getNumResults() > 0) {
if (auto resultType =
dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
os << getValueName(op->getResult(0), argSubstitutionMap) << " = ";
os << "tl.reshape(";
os << getValueName(op->getOperand(0), argSubstitutionMap) << ", [";
ArrayRef<int64_t> shape = resultType.getShape();
for (size_t i = 0; i < shape.size(); ++i) {
if (i > 0)
os << ", ";
os << shape[i];
}
os << "])\n";
return;
}
}

// Special handling for binary infix operators (a + b, a * b, etc.)
{
static llvm::StringMap<StringRef> infixOpMap = buildInfixOpMap();
auto infixIt = infixOpMap.find(opName);
if (infixIt != infixOpMap.end() && op->getNumOperands() == 2 &&
op->getNumResults() > 0) {
os << getValueName(op->getResult(0), argSubstitutionMap) << " = "
<< getValueName(op->getOperand(0), argSubstitutionMap) << " "
<< infixIt->second << " "
<< getValueName(op->getOperand(1), argSubstitutionMap) << "\n";
return;
}
}

// Special handling for unary negation
if (opName == "arith.negf" && op->getNumOperands() == 1 &&
op->getNumResults() > 0) {
os << getValueName(op->getResult(0), argSubstitutionMap) << " = -"
<< getValueName(op->getOperand(0), argSubstitutionMap) << "\n";
return;
}

// Special handling for local_alloc
if (opName == "ttg.local_alloc") {
auto it = allocInfoMap.find(op);
Expand Down
Loading