Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] add option to print SSA IDs using NameLocs as prefixes #119996

Merged
merged 8 commits into from
Dec 17, 2024

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Dec 15, 2024

This PR adds an AsmPrinter option -mlir-use-nameloc-as-prefix which uses trailing NameLocs, if the source IR provides them, as prefixes when printing SSA IDs. Note: this PR only changes AsmPrinter.

For example:

%1 = memref.load %0[] : memref<i32> loc("alice")

prints

%alice = memref.load %0[] : memref<i32>

Currently single/multiple results and bb args are handled:

%1:2 = call @make_two_results() : () -> (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#0, %arg2 = %1#0) : (index, index) -> (index, index) {
  %3 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
  %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
  scf.yield %c2, %c2 : index, index
} loc("kevin")

becomes

%bar:2 = call @make_two_results() : () -> (index, index)
%kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0) : (index, index) -> (index, index) {
  %0 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
  %harriet:2 = func.call @make_two_results() : () -> (index, index)
  scf.yield %harriet#1, %harriet#1 : index, index
}

Note, this approach preserves the NameLoc through passes so

// -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix
scf.for %arg2 = %c0 to %c4 step %c1 {
  %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
  memref.store %0, %arg1[%arg2] : memref<4xf32>
}

becomes

scf.for %arg2 = %c0 to %c4 step %c2 {
  %foo = memref.load %arg0[%arg2] : memref<4xf32>
  memref.store %foo, %arg1[%arg2] : memref<4xf32>
  ...
  %foo_1 = memref.load %arg0[%1] : memref<4xf32>
  memref.store %foo_1, %arg1[%1] : memref<4xf32>
}

Call outs

The changes here are also compatible with OpAsmOpInterface. Note though, if an op implements getAsmBlockArgumentNames but not getAsmResultNames (like linalg.generic) then the affixed loc("...") will not be processed because setResultNameFn is never called.

Testing

Besides the added lit test, I "stress tested" this by turning it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this passing test, Windows test fail there was a flake).

Copy link

github-actions bot commented Dec 15, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch from e9fd4aa to 9e6c00d Compare December 15, 2024 04:31
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch 4 times, most recently from 33116cf to 8fedeff Compare December 15, 2024 05:25
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch from 8fedeff to bbca902 Compare December 15, 2024 05:25
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch from 579868a to ca084f9 Compare December 15, 2024 06:24
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch 5 times, most recently from bf24d4a to 457f56a Compare December 15, 2024 21:40
@makslevental makslevental changed the title [mlir] Retain original identifier names for debugging v3 [mlir] Print SSA IDs using NameLocs as prefixes Dec 15, 2024
@makslevental makslevental changed the title [mlir] Print SSA IDs using NameLocs as prefixes [mlir] add option to print SSA IDs using NameLocs as prefixes Dec 15, 2024
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch 2 times, most recently from f87e003 to e7921cf Compare December 15, 2024 22:56
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch 2 times, most recently from d61562c to ecb6adc Compare December 15, 2024 23:11
@makslevental makslevental marked this pull request as ready for review December 15, 2024 23:13
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Dec 15, 2024
@llvmbot llvmbot added the mlir label Dec 15, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2024

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This PR adds an AsmPrinter option -mlir-use-nameloc-as-prefix which uses trailing NameLocs, if the source IR provides them, as prefixes when printing SSA IDs. Note: this PR only changes AsmPrinter.

For example:

%1 = memref.load %0[] : memref&lt;i32&gt; loc("alice")

prints

%alice = memref.load %0[] : memref&lt;i32&gt;

Currently single/multiple results and bb args are handled:

%1:2 = call @<!-- -->make_two_results() : () -&gt; (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#<!-- -->0, %arg2 = %1#<!-- -->0) : (index, index) -&gt; (index, index) {
  %3 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
  %c1, %c2 = func.call @<!-- -->make_two_results() : () -&gt; (index, index) loc("harriet")
  scf.yield %c2, %c2 : index, index
} loc("kevin")

becomes

%bar:2 = call @<!-- -->make_two_results() : () -&gt; (index, index)
%kevin:2 = scf.while (%arg1 = %bar#<!-- -->0, %arg2 = %bar#<!-- -->0) : (index, index) -&gt; (index, index) {
  %0 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
  %harriet:2 = func.call @<!-- -->make_two_results() : () -&gt; (index, index)
  scf.yield %harriet#<!-- -->1, %harriet#<!-- -->1 : index, index
}

The changes here are also compatible with OpAsmOpInterface. Note though, if an op implements getAsmBlockArgumentNames but not getAsmResultNames (like linalg.generic) then the affixed loc("...") will not be processed because setResultNameFn is never called.

Testing

Besides the added lit test, I "stress tested" this by turn it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be), Windows test fail there was a flake).


Full diff: https://github.com/llvm/llvm-project/pull/119996.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/OperationSupport.h (+7)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+57-6)
  • (added) mlir/test/IR/print-use-nameloc-as-prefix.mlir (+105)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
   /// Return if printer should use unique SSA IDs.
   bool shouldPrintUniqueSSAIDs() const;
 
+  /// Return if the printer should use NameLocs as prefixes when printing SSA
+  /// IDs
+  bool shouldUseNameLocAsPrefix() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
 
   /// Print unique SSA IDs for values, block arguments and naming conflicts
   bool printUniqueSSAIDsFlag : 1;
+
+  /// Print SSA IDs using NameLocs as prefixes
+  bool useNameLocAsPrefix : 1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
   return parseCommaSeparatedList(
       [&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
       "mlir-print-unique-ssa-ids", llvm::cl::init(false),
       llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
                      "and naming conflicts across all regions")};
+
+  llvm::cl::opt<bool> useNameLocAsPrefix{
+      "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+      llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
 };
 } // namespace
 
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), skipRegionsFlag(false),
       assumeVerifiedFlag(false), printLocalScope(false),
-      printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+      printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+      useNameLocAsPrefix(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
   skipRegionsFlag = clOptions->skipRegionsOpt;
   printValueUsersFlag = clOptions->printValueUsers;
   printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+  useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
   return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
 }
 
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+  return useNameLocAsPrefix;
+}
+
 //===----------------------------------------------------------------------===//
 // NewLineCounter
 //===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region &region) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
-    setValueName(arg, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, name);
+    }
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (Operation *op = region.getParentOp()) {
-      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+        alreadySetNames = true;
+      }
+    }
+  }
+
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+    for (BlockArgument arg : region.getArguments()) {
+      if (isa<NameLoc>(arg.getLoc())) {
+        auto nameLoc = cast<NameLoc>(arg.getLoc());
+        setBlockArgNameFn(arg, nameLoc.getName());
+      }
     }
   }
 
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
     }
-    setValueName(arg, specialName.str());
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, specialName.str());
+    }
   }
 
   // Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
-    setValueName(result, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() &&
+        isa<NameLoc>(result.getLoc())) {
+      auto nameLoc = cast<NameLoc>(result.getLoc());
+      setValueName(result, nameLoc.getName());
+    } else {
+      setValueName(result, name);
+    }
 
     // Record the result number for groups not anchored at 0.
     if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
     blockNames[block] = {-1, name};
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
+      alreadySetNames = true;
     }
   }
 
   unsigned numResults = op.getNumResults();
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+      numResults > 0) {
+    Value resultBegin = op.getResult(0);
+    if (isa<NameLoc>(resultBegin.getLoc())) {
+      auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+      setResultNameFn(resultBegin, nameLoc.getName());
+    }
+  }
+
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
     if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  // CHECK: %alice_0
+  %2 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+  // CHECK: %alice = memref.load %foo
+  %1 = memref.load %arg0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+  // CHECK: %foo:2
+  %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+  // CHECK: %bar:2
+  %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+  // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+  %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+    %6 = arith.cmpi slt, %arg1, %arg2 : index
+    scf.condition(%6) %arg1, %arg2 : index, index
+  } do {
+  // CHECK: ^bb0(%alice: index, %bob: index)
+  ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+    %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+    // CHECK: scf.yield %harriet#1, %harriet#1
+    scf.yield %c2, %c2 : index, index
+  } loc("kevin")
+  return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+  iterator_types = ["parallel"],
+  indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %c0
+  %0 = arith.constant 0 : index
+  // CHECK: %foo
+  %1 = arith.constant 1 : index loc("foo")
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      linalg.yield %a, %a : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+    ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+      // CHECK: linalg.yield %alice, %steve
+      linalg.yield %b, %c : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg2 = %c0 to %c4 step %c1 {
+    // CHECK-PASS-PRESERVE: %foo = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo
+    // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo_1
+    %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+    memref.store %0, %arg1[%arg2] : memref<4xf32>
+  }
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2024

@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

Changes

This PR adds an AsmPrinter option -mlir-use-nameloc-as-prefix which uses trailing NameLocs, if the source IR provides them, as prefixes when printing SSA IDs. Note: this PR only changes AsmPrinter.

For example:

%1 = memref.load %0[] : memref&lt;i32&gt; loc("alice")

prints

%alice = memref.load %0[] : memref&lt;i32&gt;

Currently single/multiple results and bb args are handled:

%1:2 = call @<!-- -->make_two_results() : () -&gt; (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#<!-- -->0, %arg2 = %1#<!-- -->0) : (index, index) -&gt; (index, index) {
  %3 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
  %c1, %c2 = func.call @<!-- -->make_two_results() : () -&gt; (index, index) loc("harriet")
  scf.yield %c2, %c2 : index, index
} loc("kevin")

becomes

%bar:2 = call @<!-- -->make_two_results() : () -&gt; (index, index)
%kevin:2 = scf.while (%arg1 = %bar#<!-- -->0, %arg2 = %bar#<!-- -->0) : (index, index) -&gt; (index, index) {
  %0 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
  %harriet:2 = func.call @<!-- -->make_two_results() : () -&gt; (index, index)
  scf.yield %harriet#<!-- -->1, %harriet#<!-- -->1 : index, index
}

The changes here are also compatible with OpAsmOpInterface. Note though, if an op implements getAsmBlockArgumentNames but not getAsmResultNames (like linalg.generic) then the affixed loc("...") will not be processed because setResultNameFn is never called.

Testing

Besides the added lit test, I "stress tested" this by turn it on by default and checking if anything broke (this commit), then fixing my bugs and stress testing again (this commit and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be), Windows test fail there was a flake).


Full diff: https://github.com/llvm/llvm-project/pull/119996.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/OperationSupport.h (+7)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+57-6)
  • (added) mlir/test/IR/print-use-nameloc-as-prefix.mlir (+105)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
   /// Return if printer should use unique SSA IDs.
   bool shouldPrintUniqueSSAIDs() const;
 
+  /// Return if the printer should use NameLocs as prefixes when printing SSA
+  /// IDs
+  bool shouldUseNameLocAsPrefix() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
 
   /// Print unique SSA IDs for values, block arguments and naming conflicts
   bool printUniqueSSAIDsFlag : 1;
+
+  /// Print SSA IDs using NameLocs as prefixes
+  bool useNameLocAsPrefix : 1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
   return parseCommaSeparatedList(
       [&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
       "mlir-print-unique-ssa-ids", llvm::cl::init(false),
       llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
                      "and naming conflicts across all regions")};
+
+  llvm::cl::opt<bool> useNameLocAsPrefix{
+      "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+      llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
 };
 } // namespace
 
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), skipRegionsFlag(false),
       assumeVerifiedFlag(false), printLocalScope(false),
-      printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+      printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+      useNameLocAsPrefix(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
   skipRegionsFlag = clOptions->skipRegionsOpt;
   printValueUsersFlag = clOptions->printValueUsers;
   printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+  useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
   return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
 }
 
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+  return useNameLocAsPrefix;
+}
+
 //===----------------------------------------------------------------------===//
 // NewLineCounter
 //===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region &region) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
-    setValueName(arg, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, name);
+    }
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (Operation *op = region.getParentOp()) {
-      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+        alreadySetNames = true;
+      }
+    }
+  }
+
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+    for (BlockArgument arg : region.getArguments()) {
+      if (isa<NameLoc>(arg.getLoc())) {
+        auto nameLoc = cast<NameLoc>(arg.getLoc());
+        setBlockArgNameFn(arg, nameLoc.getName());
+      }
     }
   }
 
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
     }
-    setValueName(arg, specialName.str());
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, specialName.str());
+    }
   }
 
   // Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
-    setValueName(result, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() &&
+        isa<NameLoc>(result.getLoc())) {
+      auto nameLoc = cast<NameLoc>(result.getLoc());
+      setValueName(result, nameLoc.getName());
+    } else {
+      setValueName(result, name);
+    }
 
     // Record the result number for groups not anchored at 0.
     if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
     blockNames[block] = {-1, name};
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
+      alreadySetNames = true;
     }
   }
 
   unsigned numResults = op.getNumResults();
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+      numResults > 0) {
+    Value resultBegin = op.getResult(0);
+    if (isa<NameLoc>(resultBegin.getLoc())) {
+      auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+      setResultNameFn(resultBegin, nameLoc.getName());
+    }
+  }
+
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
     if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  // CHECK: %alice_0
+  %2 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+  // CHECK: %alice = memref.load %foo
+  %1 = memref.load %arg0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+  // CHECK: %foo:2
+  %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+  // CHECK: %bar:2
+  %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+  // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+  %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+    %6 = arith.cmpi slt, %arg1, %arg2 : index
+    scf.condition(%6) %arg1, %arg2 : index, index
+  } do {
+  // CHECK: ^bb0(%alice: index, %bob: index)
+  ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+    %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+    // CHECK: scf.yield %harriet#1, %harriet#1
+    scf.yield %c2, %c2 : index, index
+  } loc("kevin")
+  return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+  iterator_types = ["parallel"],
+  indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %c0
+  %0 = arith.constant 0 : index
+  // CHECK: %foo
+  %1 = arith.constant 1 : index loc("foo")
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      linalg.yield %a, %a : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+    ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+      // CHECK: linalg.yield %alice, %steve
+      linalg.yield %b, %c : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg2 = %c0 to %c4 step %c1 {
+    // CHECK-PASS-PRESERVE: %foo = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo
+    // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo_1
+    %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+    memref.store %0, %arg1[%arg2] : memref<4xf32>
+  }
+  return
+}

@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch 4 times, most recently from c7ebced to a65a8b3 Compare December 15, 2024 23:43
@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch from a65a8b3 to 93be82c Compare December 15, 2024 23:43
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this is nice, minimal start. Could this be documented somewhere? (Potentially debugging guide could be good one).

@makslevental
Copy link
Contributor Author

Overall I think this is nice, minimal start. Could this be documented somewhere? (Potentially debugging guide could be good one).

Sure I can add some notes there.

Comment on lines 1573 to 1575
if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
auto nameLoc = cast<NameLoc>(arg.getLoc());
setValueName(arg, nameLoc.getName());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share the implementation of these three usages? Seems like they all pretty much have the same pattern.

Also for the NameLoc itself, you likely want to use arg.getLoc().findInstanceOf<NameLoc>(), given that in some cases (e.g. DebugInfo) you may not have a top level NameLoc.

Copy link
Contributor Author

@makslevental makslevental Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to factor out the whole conditional but then you end up hiding the printerFlags.shouldUseNameLocAsPrefix() in the helper itself. I didn't really like that (would be surprising I think) so I factored out the "get the name from a (possibly) nested NameLoc". Let me know if you want more refactoring.

Copy link
Contributor

@Wheest Wheest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this a great way to incrementally bring us closer to some of the goals set out here.

I see that multiple results are supported, but would it make sense here to support have different names for different results, e.g.:

%kevin, %alice = call @make_two_results() : () -> (index, index)

Or even a mix of result groups:

%kevin:2, %alice = call @make_three_results() : () -> (index, index, index)

This was supported in the more invasive #79704.

// CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
%max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {

I would imagine the syntax being something like:

%kevin, %alice = call @make_two_results() : () -> (index, index) loc("kevin", "alice")
%kevin:2, %alice = call @make_three_results() : () -> (index, index, index) loc("kevin", "", "alice") // empty string means use result group
%kevin, %alice:2 = call @make_three_results() : () -> (index, index, index) loc("kevin", "alice") // implicit grouping 

Although given this is building off of the existing NamedLoc system, then that would require a bigger change to change its syntax.

Having a delimiter character, e.g., loc("kevin🐢alice"), loc("kevin🐢🐢alice") could be a way to do this, but 🐢 would need to be chosen carefully.

@makslevental
Copy link
Contributor Author

makslevental commented Dec 16, 2024

I see that multiple results are supported, but would it make sense here to support have different names for different results, e.g.:

Ya that's tough exactly because you need to introduce a delimiter and then split on it. And then if something about how results are grouped in the future changes, maybe another delimiter would need to be added to support that change. Ie it would create a tiny little "NameLoc results naming DSL" and that requires and actual parser. So I decided it's easier and more robust to just let the existing disambiguation system do its thing.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG overall, thanks

@@ -1398,6 +1398,24 @@ $ tree /tmp/pipeline_output
│ │ ├── 1_1_pass4.mlir
```

* `mlir-use-nameloc-as-prefix`
* If your source IR has named locations (`loc("named_location")"`) then passing this flag will use those
names (`named_location`) to prefix the corresponding SSA identifiers:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: newline after and ```mlir to ensure highlighting triggers.

%bob = memref.load %0[] : memref<i32>
%bob_0 = memref.load %0[] : memref<i32>
```
These names will also be preserved through passes to newly created operations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/(assuming/if/

using the appropriate location.

Yes, there are some weasel words here, but there isn't necessarily one single location in all cases that should be passed (different users could have different opinions for valid reasons).

@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}

/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: missing trailing period.

// CHECK-LABEL: test_basic
func.func @test_basic() {
%0 = memref.alloc() : memref<i32>
// CHECK: %alice
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alice = memref.load ? (then its just a bit more complete local example for folks who read tests rather than docs - just have to do it once, not for the others)

@joker-eph
Copy link
Collaborator

I don't have concerns with such a change on the printer side: however I do when it'll come to the parser. Does this change make sense without the parser changes?

@makslevental
Copy link
Contributor Author

Does this change make sense without the parser changes?

Yes - it's exactly architected in this way to minimize coupling to the parser and be useful on its own such that if no parser changes are made, it's still "good enough".

@makslevental makslevental force-pushed the makslevental/retain-names-dev-v4 branch from f7c8b37 to 4168f8e Compare December 17, 2024 15:42
@makslevental makslevental merged commit f539e00 into llvm:main Dec 17, 2024
5 of 6 checks passed
@makslevental makslevental deleted the makslevental/retain-names-dev-v4 branch December 17, 2024 16:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants