diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md index 7b19a7bf6bf47..9fb0aaab06461 100644 --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -1398,6 +1398,27 @@ $ tree /tmp/pipeline_output │ │ ├── 1_1_pass4.mlir ``` +* `mlir-use-nameloc-as-prefix` + * If your source IR has named locations (`loc("named_location")"`) then passing this flag will use those + names (`named_location`) to prefix the corresponding SSA identifiers: + + ```mlir + %1 = memref.load %0[] : memref loc("alice") + %2 = memref.load %0[] : memref loc("bob") + %3 = memref.load %0[] : memref loc("bob") + ``` + + will print + + ```mlir + %alice = memref.load %0[] : memref + %bob = memref.load %0[] : memref + %bob_0 = memref.load %0[] : memref + ``` + + These names will also be preserved through passes to newly created operations if using the appropriate location. + + ## Crash and Failure Reproduction The [pass manager](#pass-manager) in MLIR contains a builtin mechanism to diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 1b93f3d3d04fe..9f2de582b03e5 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1221,6 +1221,10 @@ class OpPrintingFlags { /// Return if printer should use unique SSA IDs. bool shouldPrintUniqueSSAIDs() const; + /// Return if the printer should use NameLocs as prefixes when printing SSA + /// IDs + bool shouldUseNameLocAsPrefix() const; + private: /// Elide large elements attributes if the number of elements is larger than /// the upper limit. @@ -1254,6 +1258,9 @@ class OpPrintingFlags { /// Print unique SSA IDs for values, block arguments and naming conflicts bool printUniqueSSAIDsFlag : 1; + + /// Print SSA IDs using NameLocs as prefixes + bool useNameLocAsPrefix : 1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 61b90bc9b0a7b..99b7abe7db1f9 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default; MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } /// Parse a type list. -/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918 +/// This is out-of-line to work-around +/// https://github.com/llvm/llvm-project/issues/62918 ParseResult AsmParser::parseTypeList(SmallVectorImpl &result) { return parseCommaSeparatedList( [&]() { return parseType(result.emplace_back()); }); @@ -195,6 +196,10 @@ struct AsmPrinterOptions { "mlir-print-unique-ssa-ids", llvm::cl::init(false), llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " "and naming conflicts across all regions")}; + + llvm::cl::opt useNameLocAsPrefix{ + "mlir-use-nameloc-as-prefix", llvm::cl::init(false), + llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")}; }; } // namespace @@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), printGenericOpFormFlag(false), skipRegionsFlag(false), assumeVerifiedFlag(false), printLocalScope(false), - printValueUsersFlag(false), printUniqueSSAIDsFlag(false) { + printValueUsersFlag(false), printUniqueSSAIDsFlag(false), + useNameLocAsPrefix(false) { // Initialize based upon command line options, if they are available. if (!clOptions.isConstructed()) return; @@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags() skipRegionsFlag = clOptions->skipRegionsOpt; printValueUsersFlag = clOptions->printValueUsers; printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; + useNameLocAsPrefix = clOptions->useNameLocAsPrefix; } /// Enable the elision of large elements attributes, by printing a '...' @@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); } +/// Return if the printer should use NameLocs as prefixes when printing SSA IDs. +bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { + return useNameLocAsPrefix; +} + //===----------------------------------------------------------------------===// // NewLineCounter //===----------------------------------------------------------------------===// @@ -1506,11 +1518,22 @@ void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { } } +namespace { +/// Try to get value name from value's location, fallback to `name`. +StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { + if (auto maybeNameLoc = value.getLoc()->findInstanceOf()) + return maybeNameLoc.getName(); + return name; +} +} // namespace + void SSANameState::numberValuesInRegion(Region ®ion) { auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); assert(llvm::cast(arg).getOwner()->getParent() == ®ion && "arg not defined in current region"); + if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) + name = maybeGetValueNameFromLoc(arg, name); setValueName(arg, name); }; @@ -1553,7 +1576,10 @@ void SSANameState::numberValuesInBlock(Block &block) { specialNameBuffer.resize(strlen("arg")); specialName << nextArgumentID++; } - setValueName(arg, specialName.str()); + StringRef specialNameStr = specialName.str(); + if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) + specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr); + setValueName(arg, specialNameStr); } // Number the operations in this block. @@ -1567,6 +1593,8 @@ void SSANameState::numberValuesInOp(Operation &op) { auto setResultNameFn = [&](Value result, StringRef name) { assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result.getDefiningOp() == &op && "result not defined by 'op'"); + if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) + name = maybeGetValueNameFromLoc(result, name); setValueName(result, name); // Record the result number for groups not anchored at 0. @@ -1607,6 +1635,12 @@ void SSANameState::numberValuesInOp(Operation &op) { } Value resultBegin = op.getResult(0); + if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) { + if (auto nameLoc = resultBegin.getLoc()->findInstanceOf()) { + setValueName(resultBegin, nameLoc.getName()); + } + } + // If the first result wasn't numbered, give it a default number. if (valueIDs.try_emplace(resultBegin, nextValueID).second) ++nextValueID; diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir new file mode 100644 index 0000000000000..ddee8aed5586c --- /dev/null +++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE + +// CHECK-LABEL: test_basic +func.func @test_basic() { + %0 = memref.alloc() : memref + // CHECK: %alice = memref.load + %1 = memref.load %0[] : memref loc("alice") + return +} + +// ----- + +// CHECK-LABEL: test_repeat_namelocs +func.func @test_repeat_namelocs() { + %0 = memref.alloc() : memref + // CHECK: %alice = memref.load + %1 = memref.load %0[] : memref loc("alice") + // CHECK: %alice_0 = memref.load + %2 = memref.load %0[] : memref loc("alice") + return +} + +// ----- + +// CHECK-LABEL: test_bb_args +func.func @test_bb_args1(%arg0 : memref loc("foo")) { + // CHECK: %alice = memref.load %foo + %1 = memref.load %arg0[] : memref loc("alice") + return +} + +// ----- + +func.func private @make_two_results() -> (index, index) + +// CHECK-LABEL: test_multiple_results +func.func @test_multiple_results(%cond: i1) { + // CHECK: %foo:2 = call @make_two_results + %0:2 = call @make_two_results() : () -> (index, index) loc("foo") + // CHECK: %bar:2 = call @make_two_results + %1, %2 = call @make_two_results() : () -> (index, index) loc("bar") + + // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0) + %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) { + %6 = arith.cmpi slt, %arg1, %arg2 : index + scf.condition(%6) %arg1, %arg2 : index, index + } do { + // CHECK: ^bb0(%alice: index, %bob: index) + ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")): + %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet") + // CHECK: scf.yield %harriet#1, %harriet#1 + scf.yield %c2, %c2 : index, index + } loc("kevin") + return +} + +// ----- + +#map = affine_map<(d0) -> (d0)> +#trait = { + iterator_types = ["parallel"], + indexing_maps = [#map, #map, #map] +} + +// CHECK-LABEL: test_op_asm_interface +func.func @test_op_asm_interface(%arg0: tensor, %arg1: tensor) { + // CHECK: %c0 = arith.constant + %0 = arith.constant 0 : index + // CHECK: %foo = arith.constant + %1 = arith.constant 1 : index loc("foo") + + linalg.generic #trait ins(%arg0: tensor) outs(%arg0, %arg1: tensor, tensor) { + // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32) + ^bb0(%a: f32, %b: f32, %c: f32): + linalg.yield %a, %a : f32, f32 + } -> (tensor, tensor) + + linalg.generic #trait ins(%arg0: tensor) outs(%arg0, %arg1: tensor, tensor) { + // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32) + ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")): + // CHECK: linalg.yield %alice, %steve + linalg.yield %b, %c : f32, f32 + } -> (tensor, tensor) + + return +} + +// ----- + +// CHECK-LABEL: test_pass +func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %arg2 = %c0 to %c4 step %c1 { + // CHECK-PASS-PRESERVE: %foo = memref.load + // CHECK-PASS-PRESERVE: memref.store %foo + // CHECK-PASS-PRESERVE: %foo_1 = memref.load + // CHECK-PASS-PRESERVE: memref.store %foo_1 + %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo") + memref.store %0, %arg1[%arg2] : memref<4xf32> + } + return +}