Skip to content

Commit

Permalink
Added region arg support & ambiguous name test
Browse files Browse the repository at this point in the history
  • Loading branch information
Wheest committed Jan 29, 2024
1 parent 233a987 commit b257ab5
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 62 deletions.
28 changes: 22 additions & 6 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,9 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();

/// Store the SSA names for the current operation as attrs for debug purposes.
void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
/// Store the identifier names for the current operation as attrs for debug
/// purposes.
void storeIdentifierNames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
DenseMap<Value, StringRef> argNames;

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1273,8 +1274,8 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
}

/// Store the SSA names for the current operation as attrs for debug purposes.
void OperationParser::storeSSANames(Operation *&op,
ArrayRef<ResultRecord> resultIDs) {
void OperationParser::storeIdentifierNames(Operation *&op,
ArrayRef<ResultRecord> resultIDs) {

// Store the name(s) of the result(s) of this operation.
if (op->getNumResults() > 0) {
Expand Down Expand Up @@ -1322,6 +1323,18 @@ void OperationParser::storeSSANames(Operation *&op,
}
}
}

// Store names of region arguments (e.g., for FuncOps)
if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) {
llvm::SmallVector<llvm::StringRef, 1> regionArgNames;
for (BlockArgument arg : op->getRegion(0).getArguments()) {
auto it = argNames.find(arg);
if (it != argNames.end()) {
regionArgNames.push_back(it->second.drop_front(1));
}
}
op->setAttr("mlir.regionArgNames", builder.getStrArrayAttr(regionArgNames));
}
}

namespace {
Expand Down Expand Up @@ -2093,9 +2106,9 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);

// If enabled, store the SSA name(s) for the operation
// If enabled, store the original identifier name(s) for the operation
if (state.config.shouldRetainIdentifierNames())
storeSSANames(op, resultIDs);
storeIdentifierNames(op, resultIDs);

if (parseTrailingLocationSpecifier(op))
return nullptr;
Expand Down Expand Up @@ -2246,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
if (state.asmState)
state.asmState->addDefinition(arg, argInfo.location);

if (state.config.shouldRetainIdentifierNames())
argNames.insert({arg, argInfo.name});

// Record the definition for this argument.
if (addDefinition(argInfo, arg))
return failure();
Expand Down
122 changes: 69 additions & 53 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,9 @@ class SSANameState {
/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
void setRetainedIdentifierNames(Operation &op,
SmallVector<int, 2> &resultGroups);
SmallVector<int, 2> &resultGroups,
bool hasRegion = false);
void setRetainedIdentifierNames(Region &region);

/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
Expand Down Expand Up @@ -1492,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
setValueName(arg, name);
};

// Use manually specified region arg names if available
setRetainedIdentifierNames(region);

if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
Expand Down Expand Up @@ -1603,64 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}

void SSANameState::setRetainedIdentifierNames(
Operation &op, SmallVector<int, 2> &resultGroups) {
// Get the original names for the results if available
if (ArrayAttr resultNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
auto resultNames = resultNamesAttr.getValue();
auto results = op.getResults();
// Conservative in the case that the #results has changed
for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
auto resultName = resultNames[i].cast<StringAttr>().strref();
if (!resultName.empty()) {
if (!usedNames.count(resultName))
setValueName(results[i], resultName, /*allowNumeric=*/true);
// If a result has a name, it is the start of a result group.
if (i > 0)
resultGroups.push_back(i);
}
}
op.removeDiscardableAttr("mlir.resultNames");
}

// Get the original name for the op args if available
if (ArrayAttr opArgNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
auto opArgNames = opArgNamesAttr.getValue();
auto opArgs = op.getOperands();
// Conservative in the case that the #operands has changed
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(opArgName))
setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
void SSANameState::setRetainedIdentifierNames(Operation &op,
SmallVector<int, 2> &resultGroups,
bool hasRegion) {

// Lambda which fetches the list of relevant attributes (e.g.,
// mlir.resultNames) and associates them with the relevant values
auto handleNamedAttributes =
[this](Operation &op, const Twine &attrName, auto getValuesFunc,
std::optional<std::function<void(int)>> customAction =
std::nullopt) {
if (ArrayAttr namesAttr = op.getAttrOfType<ArrayAttr>(attrName.str())) {
auto names = namesAttr.getValue();
auto values = getValuesFunc();
// Conservative in case the number of values has changed
for (size_t i = 0; i < values.size() && i < names.size(); ++i) {
auto name = names[i].cast<StringAttr>().strref();
if (!name.empty()) {
if (!this->usedNames.count(name))
this->setValueName(values[i], name, true);
if (customAction.has_value())
customAction.value()(i);
}
}
op.removeDiscardableAttr(attrName.str());
}
};

if (hasRegion) {
// Get the original name(s) for the region arg(s) if available (e.g., for
// FuncOp args). Requires hasRegion flag to ensure scoping is correct
if (hasRegion && op.getNumRegions() > 0 &&
op.getRegion(0).getNumArguments() > 0) {
handleNamedAttributes(op, "mlir.regionArgNames",
[&]() { return op.getRegion(0).getArguments(); });
}
op.removeDiscardableAttr("mlir.opArgNames");
}

// Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
op.removeDiscardableAttr("mlir.blockName");
}

// Get the original name for the block args if available
if (ArrayAttr blockArgNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
// Conservative in the case that the #args has changed
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(blockArgName))
setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
} else {
// Get the original names for the results if available
handleNamedAttributes(
op, "mlir.resultNames", [&]() { return op.getResults(); },
[&resultGroups](int i) { /*handles result groups*/
if (i > 0)
resultGroups.push_back(i);
});

// Get the original name for the op args if available
handleNamedAttributes(op, "mlir.opArgNames",
[&]() { return op.getOperands(); });

// Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
op.removeDiscardableAttr("mlir.blockName");
}
op.removeDiscardableAttr("mlir.blockArgNames");

// Get the original name(s) for the block arg(s) if available
handleNamedAttributes(op, "mlir.blockArgNames",
[&]() { return op.getBlock()->getArguments(); });
}
return;
}

void SSANameState::setRetainedIdentifierNames(Region &region) {
if (Operation *op = region.getParentOp()) {
SmallVector<int, 2> resultGroups;
setRetainedIdentifierNames(*op, resultGroups, true);
}
}

void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
Expand Down
27 changes: 24 additions & 3 deletions mlir/test/IR/print-retain-identifiers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
// Test SSA results (with single return values)
//===----------------------------------------------------------------------===//

// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
// CHECK: func.func @add_one(%my_input: f64) -> f64 {
func.func @add_one(%my_input: f64) -> f64 {
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
%my_constant = arith.constant 1.000000e+00 : f64
// CHECK: %my_output = arith.addf %my_input, %my_constant : f64
Expand Down Expand Up @@ -71,7 +71,7 @@ func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {

// -----

////===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Test multiple return values, with a grouped value tuple
//===----------------------------------------------------------------------===//

Expand All @@ -90,3 +90,24 @@ func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64
}

// -----

//===----------------------------------------------------------------------===//
// Test identifiers which may clash with OpAsmOpInterface names (e.g., cst, %1, etc)
//===----------------------------------------------------------------------===//

// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
%my_constant = arith.constant 1.000000e+00 : f64
// CHECK: %cst = arith.constant 2.000000e+00 : f64
%cst = arith.constant 2.000000e+00 : f64
// CHECK: %cst_1 = arith.constant 3.000000e+00 : f64
%cst_1 = arith.constant 3.000000e+00 : f64
// CHECK: %1 = arith.addf %arg1, %cst : f64
%1 = arith.addf %arg1, %cst : f64
// CHECK: %0 = arith.addf %arg1, %cst_1 : f64
%0 = arith.addf %arg1, %cst_1 : f64
// CHECK: return %1 : f64
return %1 : f64
}

// -----

0 comments on commit b257ab5

Please sign in to comment.