Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Allow passing various library files to interpreter. #67120

Merged

Conversation

ingomueller-net
Copy link
Contributor

@ingomueller-net ingomueller-net commented Sep 22, 2023

The transfrom interpreter accepts an argument to a "library" file with named sequences. This patch exteneds this functionality such that (1) several such individual files are accepted and (2) folders can be passed in, in which all *.mlir files are loaded.

This PR depends on and therefor currently includes #67241 #67560.

@llvmbot llvmbot added the mlir label Sep 22, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 22, 2023

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Changes

The transfrom interpreter accepts an argument to a "library" file with named sequences. This patch exteneds this functionality such that (1) several such individual files are accepted and (2) folders can be passed in, in which all *.mlir files are loaded.


Full diff: https://github.com/llvm/llvm-project/pull/67120.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h (+29-10)
  • (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+111-45)
  • (modified) mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/Transform/match_matmul.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+11-6)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..2f33260ca9bd5a0 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,8 @@ namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
 LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +49,8 @@ LogicalResult interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -62,9 +64,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the name of the file containing
+///   - transformLibraryFileNames: if non-empty, the names of files containing
 ///     definitions of external symbols referenced in the transform script.
 ///     These definitions will be used to replace declarations.
+///   - transformLibraryDirNames: if non-empty, the name of directories
+///     containing definitions of external symbols referenced in the transform
+///     script. These definitions will be used to replace declarations.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -115,17 +120,30 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
-    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
+#define REQUIRE_PASS_LIST_OPTION(NAME)                                         \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::ListOption<std::string>>,                                      \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_LIST_OPTION(transformLibraryFileNames);
+    REQUIRE_PASS_LIST_OPTION(transformLibraryDirNames);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    StringRef transformLibraryFileName =
-        static_cast<Concrete *>(this)->transformLibraryFileName;
+    ArrayRef<std::string> transformLibraryFileNames =
+        static_cast<Concrete *>(this)->transformLibraryFileNames;
+    ArrayRef<std::string> transformLibraryDirNames =
+        static_cast<Concrete *>(this)->transformLibraryDirNames;
     return detail::interpreterBaseInitializeImpl(
-        context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule,
+        context, transformFileName, transformLibraryFileNames,
+        transformLibraryDirNames, sharedTransformModule, transformLibraryModule,
         [this](OpBuilder &builder, Location loc) {
           return static_cast<Concrete *>(this)->constructTransformModule(
               builder, loc);
@@ -159,8 +177,9 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
             op, pass->getArgument(), sharedTransformModule,
             transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->transformLibraryFileName, pass->debugPayloadRootTag,
-            pass->debugTransformRootTag, binaryName)) ||
+            pass->transformLibraryFileNames, pass->transformLibraryDirNames,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
     }
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..1bdbba8cb6346b6 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -161,15 +161,24 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
-               const Pass::Option<std::string> &transformLibraryFileName,
+               const Pass::ListOption<std::string> &transformLibraryFileNames,
+               const Pass::ListOption<std::string> &transformLibraryDirNames,
                StringRef binaryName) {
-  std::string transformLibraryOption = "";
-  if (!transformLibraryFileName.empty()) {
-    transformLibraryOption =
-        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
-                      transformLibraryFileName.getValue())
-            .str();
+  std::string transformLibraryOptions = "";
+  {
+    llvm::raw_string_ostream optionStream(transformLibraryOptions);
+    if (!transformLibraryFileNames.empty()) {
+      optionStream << " " << transformLibraryFileNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryFileNames, optionStream, ",");
+      optionStream << "'";
+    }
+    if (!transformLibraryDirNames.empty()) {
+      optionStream << " " << transformLibraryDirNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryDirNames, optionStream, ",");
+      optionStream << "'";
+    }
   }
+
   os << llvm::formatv(
       "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
@@ -180,7 +189,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      transformLibraryOption, binaryName);
+      transformLibraryOptions, binaryName);
   return os;
 }
 
@@ -200,7 +209,8 @@ void saveReproToTempFile(
     llvm::raw_ostream &os, Operation *target, Operation *transform,
     StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
@@ -227,7 +237,8 @@ void saveReproToTempFile(
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
                  debugPayloadRootTag, debugTransformRootTag,
-                 transformLibraryFileName, binaryName)
+                 transformLibraryFileNames, transformLibraryDirNames,
+                 binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -238,7 +249,8 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -279,10 +291,10 @@ static void performOptionalDebugActions(
 
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
-    printReproCall(llvm::dbgs() << "cat <<EOF | ",
-                   root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag,
-                   transformLibraryFileName, binaryName)
+    printReproCall(
+        llvm::dbgs() << "cat <<EOF | ", root->getName().getStringRef(),
+        passName, debugPayloadRootTag, debugTransformRootTag,
+        transformLibraryFileNames, transformLibraryDirNames, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -292,7 +304,8 @@ static void performOptionalDebugActions(
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
                         debugPayloadRootTag, debugTransformRootTag,
-                        transformLibraryFileName, binaryName);
+                        transformLibraryFileNames, transformLibraryDirNames,
+                        binaryName);
   });
 
   // Remove temporary attributes if they were set.
@@ -383,7 +396,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -449,7 +463,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              transformLibraryFileName, binaryName);
+                              transformLibraryFileNames,
+                              transformLibraryDirNames, binaryName);
 
   // Step 5
   // ------
@@ -460,51 +475,102 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
-    return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
-    return failure();
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile))) {
+      emitError(loc, "failed to parse transform module");
+      return failure();
+    }
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) {
+      emitError(loc, "failed to parse transform module");
+      return failure();
+    }
+  }
 
-  OwningOpRef<ModuleOp> parsedLibrary;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
-    return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
-    return failure();
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  libraryFileNames.append(transformLibraryFileNames.begin(),
+                          transformLibraryFileNames.end());
+  // XXX: `.mlir` files from transformLibraryDirNames
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, libraryFileName,
+                                            parsedLibrary))) {
+      emitError(loc, "failed to parse transform library module");
+      return failure();
+    }
+    if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) {
+      emitError(loc, "failed to verify transform module");
+      return failure();
+    }
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  // Build shared transform module.
+  if (moduleFromFile) {
+    sharedTransformModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
   } else if (moduleBuilder) {
     // TODO: better location story.
-    auto location = UnknownLoc::get(context);
+    auto loc = UnknownLoc::get(context);
     auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        ModuleOp::create(location, "__transform"));
+        ModuleOp::create(loc, "__transform"));
 
     OpBuilder b(context);
     b.setInsertionPointToEnd(localModule->get().getBody());
-    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
-      if (failed(*result))
+    if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
+      if (failed(*result)) {
+        emitError(loc, "failed to create shared transform module");
         return failure();
-      module = std::move(localModule);
+      }
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (parsedLibraries.empty())
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+  // Merge parsed libraries into one module.
+  // TODO: better location story.
+  auto loc = UnknownLoc::get(context);
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(loc, "__transform");
+  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                       UnitAttr::get(context));
+
+  IRRewriter rewriter(context);
+  rewriter.setInsertionPointToEnd(mergedParsedLibraries->getBody());
+  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+    rewriter.inlineBlockBefore(parsedLibrary->getBody(),
+                               mergedParsedLibraries->getBody(),
+                               mergedParsedLibraries->getBody()->end());
+    if (failed(mlir::verify(*mergedParsedLibraries))) {
+      emitError(loc, "failed to verify merged transform module");
+      return failure();
+    }
+  }
+
+  // Merge parsed libraries into shared module or return as library module.
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     mergedParsedLibraries.get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(mergedParsedLibraries));
   }
   return success();
 }
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..c7c055c7567e6f3 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..3e06f2ca0895841 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..28e5256385efaea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,8 @@ class TestTransformDialectInterpreterPass
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
-            transformFileName, transformLibraryFileName, debugPayloadRootTag,
+            transformFileName, transformLibraryFileNames,
+            transformLibraryDirNames, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -216,12 +217,16 @@ class TestTransformDialectInterpreterPass
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
-  Option<std::string> transformLibraryFileName{
-      *this, "transform-library-file-name", llvm::cl::init(""),
+  ListOption<std::string> transformLibraryFileNames{
+      *this, "transform-library-file-names", llvm::cl::ZeroOrMore,
+      llvm::cl::desc("Optional filenames containing transform dialect symbol "
+                     "definitions to be injected into the transform module.")};
+  ListOption<std::string> transformLibraryDirNames{
+      *this, "transform-library-dir-names", llvm::cl::ZeroOrMore,
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
-
+          "Optional directories containing transform dialect symbol "
+          "definitions to be injected into the transform module. All '.mlir' "
+          "files rooted under this directory will be loaded.")};
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),
       llvm::cl::desc("test the generation of the transform module during pass "

@ingomueller-net
Copy link
Contributor Author

@qcolombet: This just got a massive feature boost and should be relatively usable now.

@ingomueller-net ingomueller-net force-pushed the transform-interpreter-files branch 4 times, most recently from 2957ce3 to e419307 Compare September 25, 2023 08:21
@ingomueller-net ingomueller-net force-pushed the transform-interpreter-files branch 4 times, most recently from e65d6a9 to a9a5050 Compare October 4, 2023 12:57
@ingomueller-net ingomueller-net marked this pull request as ready for review October 4, 2023 12:57
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// Assemble list of library files.
SmallVector<std::string> libraryFileNames;
for (const std::string &path : transformLibraryPaths) {
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transformFileName -> path

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed in latest commit.

Comment on lines 635 to 668
SmallVector<std::string> libraryFileNames;
for (const std::string &path : transformLibraryPaths) {
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);

if (llvm::sys::fs::is_regular_file(path)) {
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
libraryFileNames.push_back(path);
continue;
}

if (!llvm::sys::fs::is_directory(path)) {
return emitError(loc)
<< "'" << path << "' is neither a file nor a directory";
}

LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

std::error_code ec;
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
it != itEnd && !ec; it.increment(ec)) {
const std::string &fileName = it->path();

if (it->type() != llvm::sys::fs::file_type::regular_file) {
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
<< "'\n");
continue;
}

if (!StringRef(fileName).endswith(".mlir")) {
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
<< "' because it does not end with '.mlir'\n");
continue;
}

LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
libraryFileNames.push_back(fileName);
}

if (ec)
return emitError(loc) << "error while opening files in '" << path
<< "': " << ec.message();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we refactor this expansion into a static function that we just call from here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Done in latest commit.

Copy link
Contributor Author

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! Awesome, this is finally getting to the finish line :) Will merge when CI has passed.

// Assemble list of library files.
SmallVector<std::string> libraryFileNames;
for (const std::string &path : transformLibraryPaths) {
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed in latest commit.

Comment on lines 635 to 668
SmallVector<std::string> libraryFileNames;
for (const std::string &path : transformLibraryPaths) {
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);

if (llvm::sys::fs::is_regular_file(path)) {
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
libraryFileNames.push_back(path);
continue;
}

if (!llvm::sys::fs::is_directory(path)) {
return emitError(loc)
<< "'" << path << "' is neither a file nor a directory";
}

LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

std::error_code ec;
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
it != itEnd && !ec; it.increment(ec)) {
const std::string &fileName = it->path();

if (it->type() != llvm::sys::fs::file_type::regular_file) {
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
<< "'\n");
continue;
}

if (!StringRef(fileName).endswith(".mlir")) {
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
<< "' because it does not end with '.mlir'\n");
continue;
}

LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
libraryFileNames.push_back(fileName);
}

if (ec)
return emitError(loc) << "error while opening files in '" << path
<< "': " << ec.message();
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Done in latest commit.

The transfrom interpreter accepts an argument to a "library" file with
named sequences. This patch exteneds this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
@ingomueller-net ingomueller-net merged commit 6a2071c into llvm:main Oct 6, 2023
3 checks passed
@ingomueller-net ingomueller-net deleted the transform-interpreter-files branch October 6, 2023 10:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants