Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions bin/triton-tensor-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,41 +39,44 @@ using namespace mlir;
// CLI options
//===--------------------------------------------------------------------===//

cl::OptionCategory PrinterCategory("Available Print Options",
"Options for the tensor layout printing.");
static cl::OptionCategory &getPrinterCategory() {
static cl::OptionCategory PrinterCategory(
"Available Print Options", "Options for the tensor layout printing.");
return PrinterCategory;
}

static cl::opt<std::string> InputFile(
"i", cl::desc("File that contains the tensor data layout attributes"),
cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory));
cl::init(""), cl::value_desc("filename"), cl::cat(getPrinterCategory()));

static cl::opt<std::string>
OutputFile("o", cl::desc("Output file to write the layout into"),
cl::init(""), cl::value_desc("filename"),
cl::cat(PrinterCategory));
cl::cat(getPrinterCategory()));

static cl::opt<std::string>
DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"),
cl::value_desc("layout-string"), cl::init(""),
cl::cat(PrinterCategory));
cl::cat(getPrinterCategory()));

static cl::list<std::string>
AliasName("alias-names",
cl::desc("A list of alias names (separated by comma) of the "
"layout attributes in the input file"),
cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated,
cl::ZeroOrMore, cl::cat(PrinterCategory));
cl::ZeroOrMore, cl::cat(getPrinterCategory()));

static cl::opt<bool> UseHWPointOfView(
"use-hw-view",
llvm::cl::desc(
"Print the layout in hardware point of view. This means the output is "
"from the warp's perspective. Otherwise, the output is from the "
"tensor's perspective (e.g., each element maps to xxx thread)."),
cl::init(false), cl::cat(PrinterCategory));
cl::init(false), cl::cat(getPrinterCategory()));

static cl::opt<std::string> TensorStr(
"t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"),
cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory));
cl::init(""), cl::value_desc("tensor-type"), cl::cat(getPrinterCategory()));

//===--------------------------------------------------------------------===//
// Helper functions
Expand Down Expand Up @@ -180,7 +183,7 @@ static LogicalResult printLayoutFromString(MLIRContext *context,
//===--------------------------------------------------------------------===//

int main(int argc, char **argv) {
cl::HideUnrelatedOptions(PrinterCategory);
cl::HideUnrelatedOptions(getPrinterCategory());
cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

DialectRegistry registry;
Expand Down
Loading