Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Fix handling of transitive includes in interpreter. #67241

Conversation

ingomueller-net
Copy link
Contributor

@ingomueller-net ingomueller-net commented Sep 23, 2023

Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.

This PR extends the loading missing as follows: in defineDeclaredSymbols, not only are the definitions inserted that are forward-declared in the main module, but any such inserted definition is scanned for further dependencies, and those are processed in the same way as the forward-declarations from the main module.

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 23, 2023

@llvm/pr-subscribers-mlir

Changes

Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.

This PR extends the loading missing as follows: in defineDeclaredSymbols, not only are the definitions inserted that are forward-declared in the main module, but any such inserted definition is scanned for further dependencies, and those are processed in the same way as the forward-declarations from the main module.


Full diff: https://github.com/llvm/llvm-project/pull/67241.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+61-8)
  • (added) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir (+25)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir (+5)
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..aa2b1157c254b47 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
   auto readOnlyName =
       StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
 
+  // Collect symbols missing in the block.
+  SmallVector<SymbolOpInterface> missingSymbols;
+  LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
   for (Operation &op : llvm::make_early_inc_range(block)) {
     LLVM_DEBUG(DBGS() << op << "\n");
     auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       continue;
     if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
       continue;
+    LLVM_DEBUG(DBGS() << "  -> symbol missing\n");
+    missingSymbols.push_back(symbol);
+  }
 
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+  // Resolve missing symbols until they are all resolved.
+  while (!missingSymbols.empty()) {
+    SymbolOpInterface symbol = missingSymbols.pop_back_val();
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
+                      << symbol.getNameAttr().getValue() << ": ");
+    SymbolTable definitionsSymbolTable(definitions);
+    Operation *externalSymbol =
+        definitionsSymbolTable.lookup(symbol.getNameAttr());
     if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
         externalSymbol->getRegion(0).empty()) {
       LLVM_DEBUG(llvm::dbgs() << "not found\n");
       continue;
     }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
     auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
     if (!symbolFunc || !externalSymbolFunc) {
       LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
+                            << externalSymbol->getLoc() << "\n");
     if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
       return symbolFunc.emitError()
              << "external definition has a mismatching signature ("
@@ -367,10 +378,52 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       }
     }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
+    OpBuilder builder(symbol);
+    builder.setInsertionPoint(symbol);
     builder.clone(*externalSymbol);
     symbol->erase();
+
+    LLVM_DEBUG(DBGS() << "scanning definition of @"
+                      << externalSymbolFunc.getNameAttr().getValue()
+                      << " for symbol usages\n");
+    externalSymbolFunc.walk([&](CallOpInterface callOp) {
+      LLVM_DEBUG(DBGS() << "  found symbol usage in:\n" << callOp << "\n");
+      CallInterfaceCallable callable = callOp.getCallableForCallee();
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(DBGS() << "    not a 'SymbolRefAttr'\n");
+        return WalkResult::advance();
+      }
+
+      StringRef callableSymbol =
+          cast<SymbolRefAttr>(callable).getLeafReference();
+      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbol
+                        << " in definitions: ");
+
+      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(llvm::dbgs() << "not found\n");
+        return WalkResult::advance();
+      }
+      LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
+                              << callableOp->getLoc() << "\n");
+
+      if (!block.getParent() || !block.getParent()->getParentOp()) {
+        LLVM_DEBUG(DBGS() << "could not get parent of provided block");
+        return WalkResult::advance();
+      }
+
+      SymbolTable targetSymbolTable(block.getParent()->getParentOp());
+      if (targetSymbolTable.lookup(callableSymbol)) {
+        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbol
+                          << " already present in target\n");
+        return WalkResult::advance();
+      }
+
+      LLVM_DEBUG(DBGS() << "    cloning op into target\n");
+      builder.clone(*callableOp);
+
+      return WalkResult::advance();
+    });
   }
 
   return success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
new file mode 100644
index 000000000000000..3a122ce2f77c3a8
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @bar(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..9aa2d46d5abb995 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
+  transform.named_sequence @bar(%arg0: !transform.any_op) {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
   transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield

Comment on lines 410 to 422
if (!block.getParent() || !block.getParent()->getParentOp()) {
LLVM_DEBUG(DBGS() << "could not get parent of provided block");
return WalkResult::advance();
}

SymbolTable targetSymbolTable(block.getParent()->getParentOp());
if (targetSymbolTable.lookup(callableSymbol)) {
LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol
<< " already present in target\n");
return WalkResult::advance();
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not perfectly happy with this part yet. I am wondering whether defineDeclaredSymbols shouldn't accept the parent op right away rather than a block. If that isn't acceptable for some reason, I wonder whether this shouldn't fail rather just outputting a silent debug message.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to be careful here, as this basically a proto-linker. One low-tech solution could be to just clone all symbols from the library file(s) into the main one, not just the declared/referenced ones. The difficult issue in any case is name clashes between files. AFAIR, named_sequence doesn't have the visibility attribute and it probably should for this purpose. It would explicitly tell us where renaming a symbol to avoid the name clash is acceptable (private symbol) and when it isn't. When it isn't, the process needs to run in three stages: (1) collect all definitions (i.e., with bodies) of named sequences that need to be cloned without actually cloning them + check that there is exactly one definition of each if renaming is not allowed; (2) if needed and allowed, rename private symbols and update their uses in each module individually; (3) actually clone (also let's consider moving them instead).

Limiting the process to only the symbols that are actually needed is an optimization. We can do it, but can also punt. In the general case, we need something like connected components for all named sequences in the main file as it may also have unused declarations. There's CallGraph for that and LLVM has algorithms to compute connected components so no need to write this manually.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is kind of a linker. I am actually wondering if it shouldn't live in a more central place?

But, boy, it's the third time that I realize that this is more involved than I originally thought :P But I think I understand the issue and should be able to fix it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually, it could become a more generic utility. But I wouldn't go out of my way right now as I'm not aware of any other use case, just keeping it relatively separable by being interface-based (symbol tables, call graph) should be sufficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NamedSequenceOp has the FunctionOpInterface, which inherits from SymbolOpInterface, which is where getVisibility is declared/defined, so named sequences already have the visibility attribute as it seems.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think that you propose makes sense. Linking in all symbols instead of just the referenced once is something that I need for #67120.

One question, though: SymbolOpInterface alone does not know the concept of "external" (aka the distinction between definition vs declaration); only FunctionOpInterface does. However, the previous implementation of defineDeclaredSymbols tested for SymbolOpInterface and symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty(), which corresponds to the definition of isExternal in FunctionOpInterface. I think the more general case to deal with that is to implement the case that allows external functions (aka declarations) for the special case FunctionOpInterface and to not allow clashes of public symbols otherwise. Do you agree?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good as a first approximation. I suspect that one will have to lift the declaration/definition distinction out of FunctionOpInterface into either SymbolOpInterface or a new interface as it makes sense for other kinds of symbols such as memref globals. We don't care about these for our use case, so let's stick with functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limiting the process to only the symbols that are actually needed is an optimization. We can do it, but can also punt. In the general case, we need something like connected components for all named sequences in the main file as it may also have unused declarations. There's CallGraph for that and LLVM has algorithms to compute connected components so no need to write this manually.

I briefly started to think about this. The call graph analysis works on CallOpInterface and CallableOpInterface, which I think should be enough for transform.include and transform.sequence at the moment. However, if some op used a symbol defined in the library module that does not implement the CallableOpInterface, we wouldn't capture that relationship with the call graph analysis. I don't know of any (transform) op that does that currently, so it may not be an issue. For full generality, wouldn't we have to build a "used by" graph of symbols (or even a mixture of that and the call graph)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the general case, yes, we would need to care about things like addressof that are not a call. I don't have an immediate plan for having something like that in the transform dialect.

Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may include other sequences on their
own. Until now, if such sequences were not *also* declared in the main
transform module, the interpreter would fail to resolve them. Forward
declaring all of them is undesirable as it defeats the purpose of
encapsulation into library modules.

This PR extends the loading missing as follows: in
`defineDeclaredSymbols`, not only are the definitions inserted that are
forward-declared in the main module, but any such inserted definition is
scanned for further dependencies, and those are processed in the same
way as the forward-declarations from the main module.
@ingomueller-net ingomueller-net force-pushed the transform-interpreter-recursive-include branch from ab3e70b to 2b7d6b4 Compare September 25, 2023 08:03
@ingomueller-net
Copy link
Contributor Author

Closing in favor or #67560, which has a better design and is more mature and complete.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants