-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] add option to print SSA IDs using NameLoc
s as prefixes
#119996
[mlir] add option to print SSA IDs using NameLoc
s as prefixes
#119996
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
e9fd4aa
to
9e6c00d
Compare
33116cf
to
8fedeff
Compare
8fedeff
to
bbca902
Compare
579868a
to
ca084f9
Compare
bf24d4a
to
457f56a
Compare
NameLoc
s as prefixes
f87e003
to
e7921cf
Compare
d61562c
to
ecb6adc
Compare
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesThis PR adds an For example: %1 = memref.load %0[] : memref<i32> loc("alice") prints %alice = memref.load %0[] : memref<i32> Currently single/multiple results and bb args are handled: %1:2 = call @<!-- -->make_two_results() : () -> (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#<!-- -->0, %arg2 = %1#<!-- -->0) : (index, index) -> (index, index) {
%3 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
%c1, %c2 = func.call @<!-- -->make_two_results() : () -> (index, index) loc("harriet")
scf.yield %c2, %c2 : index, index
} loc("kevin") becomes %bar:2 = call @<!-- -->make_two_results() : () -> (index, index)
%kevin:2 = scf.while (%arg1 = %bar#<!-- -->0, %arg2 = %bar#<!-- -->0) : (index, index) -> (index, index) {
%0 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
%harriet:2 = func.call @<!-- -->make_two_results() : () -> (index, index)
scf.yield %harriet#<!-- -->1, %harriet#<!-- -->1 : index, index
} The changes here are also compatible with TestingBesides the added lit test, I "stress tested" this by turn it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be), Windows test fail there was a flake). Full diff: https://github.com/llvm/llvm-project/pull/119996.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;
+ /// Return if the printer should use NameLocs as prefixes when printing SSA
+ /// IDs
+ bool shouldUseNameLocAsPrefix() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;
+
+ /// Print SSA IDs using NameLocs as prefixes
+ bool useNameLocAsPrefix : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
"and naming conflicts across all regions")};
+
+ llvm::cl::opt<bool> useNameLocAsPrefix{
+ "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+ llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
};
} // namespace
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+ printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+ useNameLocAsPrefix(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+ useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+ return useNameLocAsPrefix;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
- setValueName(arg, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, name);
+ }
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
- if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+ if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ alreadySetNames = true;
+ }
+ }
+ }
+
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+ for (BlockArgument arg : region.getArguments()) {
+ if (isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setBlockArgNameFn(arg, nameLoc.getName());
+ }
}
}
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
}
- setValueName(arg, specialName.str());
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, specialName.str());
+ }
}
// Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
- setValueName(result, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() &&
+ isa<NameLoc>(result.getLoc())) {
+ auto nameLoc = cast<NameLoc>(result.getLoc());
+ setValueName(result, nameLoc.getName());
+ } else {
+ setValueName(result, name);
+ }
// Record the result number for groups not anchored at 0.
if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
blockNames[block] = {-1, name};
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
+ alreadySetNames = true;
}
}
unsigned numResults = op.getNumResults();
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+ numResults > 0) {
+ Value resultBegin = op.getResult(0);
+ if (isa<NameLoc>(resultBegin.getLoc())) {
+ auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+ setResultNameFn(resultBegin, nameLoc.getName());
+ }
+ }
+
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ // CHECK: %alice_0
+ %2 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+ // CHECK: %alice = memref.load %foo
+ %1 = memref.load %arg0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+ // CHECK: %foo:2
+ %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+ // CHECK: %bar:2
+ %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+ // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+ %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+ %6 = arith.cmpi slt, %arg1, %arg2 : index
+ scf.condition(%6) %arg1, %arg2 : index, index
+ } do {
+ // CHECK: ^bb0(%alice: index, %bob: index)
+ ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+ %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+ // CHECK: scf.yield %harriet#1, %harriet#1
+ scf.yield %c2, %c2 : index, index
+ } loc("kevin")
+ return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+ iterator_types = ["parallel"],
+ indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %c0
+ %0 = arith.constant 0 : index
+ // CHECK: %foo
+ %1 = arith.constant 1 : index loc("foo")
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ linalg.yield %a, %a : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+ ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+ // CHECK: linalg.yield %alice, %steve
+ linalg.yield %b, %c : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ // CHECK-PASS-PRESERVE: %foo = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo
+ // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo_1
+ %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+ memref.store %0, %arg1[%arg2] : memref<4xf32>
+ }
+ return
+}
|
@llvm/pr-subscribers-mlir-core Author: Maksim Levental (makslevental) ChangesThis PR adds an For example: %1 = memref.load %0[] : memref<i32> loc("alice") prints %alice = memref.load %0[] : memref<i32> Currently single/multiple results and bb args are handled: %1:2 = call @<!-- -->make_two_results() : () -> (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#<!-- -->0, %arg2 = %1#<!-- -->0) : (index, index) -> (index, index) {
%3 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
%c1, %c2 = func.call @<!-- -->make_two_results() : () -> (index, index) loc("harriet")
scf.yield %c2, %c2 : index, index
} loc("kevin") becomes %bar:2 = call @<!-- -->make_two_results() : () -> (index, index)
%kevin:2 = scf.while (%arg1 = %bar#<!-- -->0, %arg2 = %bar#<!-- -->0) : (index, index) -> (index, index) {
%0 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
%harriet:2 = func.call @<!-- -->make_two_results() : () -> (index, index)
scf.yield %harriet#<!-- -->1, %harriet#<!-- -->1 : index, index
} The changes here are also compatible with TestingBesides the added lit test, I "stress tested" this by turn it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be), Windows test fail there was a flake). Full diff: https://github.com/llvm/llvm-project/pull/119996.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;
+ /// Return if the printer should use NameLocs as prefixes when printing SSA
+ /// IDs
+ bool shouldUseNameLocAsPrefix() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;
+
+ /// Print SSA IDs using NameLocs as prefixes
+ bool useNameLocAsPrefix : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
"and naming conflicts across all regions")};
+
+ llvm::cl::opt<bool> useNameLocAsPrefix{
+ "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+ llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
};
} // namespace
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+ printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+ useNameLocAsPrefix(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+ useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+ return useNameLocAsPrefix;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
- setValueName(arg, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, name);
+ }
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
- if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+ if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ alreadySetNames = true;
+ }
+ }
+ }
+
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+ for (BlockArgument arg : region.getArguments()) {
+ if (isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setBlockArgNameFn(arg, nameLoc.getName());
+ }
}
}
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
}
- setValueName(arg, specialName.str());
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, specialName.str());
+ }
}
// Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
- setValueName(result, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() &&
+ isa<NameLoc>(result.getLoc())) {
+ auto nameLoc = cast<NameLoc>(result.getLoc());
+ setValueName(result, nameLoc.getName());
+ } else {
+ setValueName(result, name);
+ }
// Record the result number for groups not anchored at 0.
if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
blockNames[block] = {-1, name};
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
+ alreadySetNames = true;
}
}
unsigned numResults = op.getNumResults();
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+ numResults > 0) {
+ Value resultBegin = op.getResult(0);
+ if (isa<NameLoc>(resultBegin.getLoc())) {
+ auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+ setResultNameFn(resultBegin, nameLoc.getName());
+ }
+ }
+
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ // CHECK: %alice_0
+ %2 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+ // CHECK: %alice = memref.load %foo
+ %1 = memref.load %arg0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+ // CHECK: %foo:2
+ %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+ // CHECK: %bar:2
+ %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+ // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+ %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+ %6 = arith.cmpi slt, %arg1, %arg2 : index
+ scf.condition(%6) %arg1, %arg2 : index, index
+ } do {
+ // CHECK: ^bb0(%alice: index, %bob: index)
+ ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+ %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+ // CHECK: scf.yield %harriet#1, %harriet#1
+ scf.yield %c2, %c2 : index, index
+ } loc("kevin")
+ return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+ iterator_types = ["parallel"],
+ indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %c0
+ %0 = arith.constant 0 : index
+ // CHECK: %foo
+ %1 = arith.constant 1 : index loc("foo")
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ linalg.yield %a, %a : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+ ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+ // CHECK: linalg.yield %alice, %steve
+ linalg.yield %b, %c : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ // CHECK-PASS-PRESERVE: %foo = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo
+ // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo_1
+ %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+ memref.store %0, %arg1[%arg2] : memref<4xf32>
+ }
+ return
+}
|
c7ebced
to
a65a8b3
Compare
a65a8b3
to
93be82c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall I think this is nice, minimal start. Could this be documented somewhere? (Potentially debugging guide could be good one).
Sure I can add some notes there. |
mlir/lib/IR/AsmPrinter.cpp
Outdated
if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) { | ||
auto nameLoc = cast<NameLoc>(arg.getLoc()); | ||
setValueName(arg, nameLoc.getName()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you share the implementation of these three usages? Seems like they all pretty much have the same pattern.
Also for the NameLoc itself, you likely want to use arg.getLoc().findInstanceOf<NameLoc>()
, given that in some cases (e.g. DebugInfo) you may not have a top level NameLoc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to factor out the whole conditional but then you end up hiding the printerFlags.shouldUseNameLocAsPrefix()
in the helper itself. I didn't really like that (would be surprising I think) so I factored out the "get the name from a (possibly) nested NameLoc
". Let me know if you want more refactoring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall this a great way to incrementally bring us closer to some of the goals set out here.
I see that multiple results are supported, but would it make sense here to support have different names for different results, e.g.:
%kevin, %alice = call @make_two_results() : () -> (index, index)
Or even a mix of result groups:
%kevin:2, %alice = call @make_three_results() : () -> (index, index, index)
This was supported in the more invasive #79704.
llvm-project/mlir/test/IR/print-retain-identifiers.mlir
Lines 82 to 83 in b257ab5
// CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { | |
%max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { |
I would imagine the syntax being something like:
%kevin, %alice = call @make_two_results() : () -> (index, index) loc("kevin", "alice")
%kevin:2, %alice = call @make_three_results() : () -> (index, index, index) loc("kevin", "", "alice") // empty string means use result group
%kevin, %alice:2 = call @make_three_results() : () -> (index, index, index) loc("kevin", "alice") // implicit grouping
Although given this is building off of the existing NamedLoc
system, then that would require a bigger change to change its syntax.
Having a delimiter character, e.g., loc("kevin🐢alice")
, loc("kevin🐢🐢alice")
could be a way to do this, but 🐢 would need to be chosen carefully.
Ya that's tough exactly because you need to introduce a delimiter and then split on it. And then if something about how results are grouped in the future changes, maybe another delimiter would need to be added to support that change. Ie it would create a tiny little " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG overall, thanks
@@ -1398,6 +1398,24 @@ $ tree /tmp/pipeline_output | |||
│ │ ├── 1_1_pass4.mlir | |||
``` | |||
|
|||
* `mlir-use-nameloc-as-prefix` | |||
* If your source IR has named locations (`loc("named_location")"`) then passing this flag will use those | |||
names (`named_location`) to prefix the corresponding SSA identifiers: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: newline after and ```mlir to ensure highlighting triggers.
mlir/docs/PassManagement.md
Outdated
%bob = memref.load %0[] : memref<i32> | ||
%bob_0 = memref.load %0[] : memref<i32> | ||
``` | ||
These names will also be preserved through passes to newly created operations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/(assuming/if/
using the appropriate location.
Yes, there are some weasel words here, but there isn't necessarily one single location in all cases that should be passed (different users could have different opinions for valid reasons).
mlir/lib/IR/AsmPrinter.cpp
Outdated
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { | |||
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); | |||
} | |||
|
|||
/// Return if the printer should use NameLocs as prefixes when printing SSA IDs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: missing trailing period.
// CHECK-LABEL: test_basic | ||
func.func @test_basic() { | ||
%0 = memref.alloc() : memref<i32> | ||
// CHECK: %alice |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alice = memref.load ? (then its just a bit more complete local example for folks who read tests rather than docs - just have to do it once, not for the others)
I don't have concerns with such a change on the printer side: however I do when it'll come to the parser. Does this change make sense without the parser changes? |
Yes - it's exactly architected in this way to minimize coupling to the parser and be useful on its own such that if no parser changes are made, it's still "good enough". |
f7c8b37
to
4168f8e
Compare
This PR adds an
AsmPrinter
option-mlir-use-nameloc-as-prefix
which uses trailingNameLoc
s, if the source IR provides them, as prefixes when printing SSA IDs. Note: this PR only changesAsmPrinter
.For example:
prints
Currently single/multiple results and bb args are handled:
becomes
Note, this approach preserves the
NameLoc
through passes sobecomes
Call outs
The changes here are also compatible with
OpAsmOpInterface
. Note though, if an op implementsgetAsmBlockArgumentNames
but notgetAsmResultNames
(likelinalg.generic
) then the affixedloc("...")
will not be processed becausesetResultNameFn
is never called.Testing
Besides the added lit test, I "stress tested" this by turning it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this passing test, Windows test fail there was a flake).