
[mlir][sparse] refine sparse assembler strategy #80521

Merged: 2 commits into llvm:main on Feb 5, 2024

Conversation

aartbik
Contributor

aartbik commented on Feb 3, 2024

Rewrite all public methods, turning the originals into internal, private methods and exposing wrappers under the original names. This works a bit better in practice (for example, when combined with the c-interface mechanism of torch-mlir).
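For illustration only (not part of the patch text), the symbol layout this produces for an entry point carrying llvm.emit_c_interface, as checked by the new torch_linalg.mlir test, looks as follows after running the pipeline from its RUN lines (the input file name is a placeholder):

  mlir-opt torch_linalg.mlir --sparse-assembler --sparsifier

  llvm.func @main                // public wrapper with the original name, calls @_internal_main
  llvm.func @_mlir_ciface_main   // c-interface wrapper emitted for llvm.emit_c_interface, calls @main
  llvm.func @_internal_main      // original body, renamed and made private

Compared with the previous spiface_ prefix scheme, the externally visible symbol and its _mlir_ciface_ companion now keep the original name, which lines up better with torch-mlir's c-interface mechanism.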

@llvmbot
Collaborator

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes

Rewrite all public methods, turning the originals into internal, private methods and exposing wrappers under the original names. This works a bit better in practice (for example, when combined with the c-interface mechanism of torch-mlir).


Full diff: https://github.com/llvm/llvm-project/pull/80521.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+30-25)
  • (modified) mlir/test/Dialect/SparseTensor/external.mlir (+27-22)
  • (added) mlir/test/Dialect/SparseTensor/torch_linalg.mlir (+55)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 8772d5f127949..58e2d6f32386c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   let summary = "Add [dis]assemble operations on external sparse tensors";
   let description = [{
     A pass that converts public entry methods that use sparse tensors as
-    input parameters and/or output return values into wrapper functions
+    input parameters and/or output return values into wrapper methods
     that [dis]assemble the individual tensors that constitute the actual
     storage used externally into MLIR sparse tensors. This pass can be used
     to prepare the public entry methods of a program that is compiled by the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index f9b6397e0f086..b4cefec8fb21f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
 namespace {
 
 // A rewriting rules that converts public entry methods that use sparse tensors
-// as input parameters and/or output return values into wrapper functions
-// that [dis]assemble the individual tensors that constitute the actual
-// storage used externally into MLIR sparse tensors.
+// as input parameters and/or output return values into wrapper methods that
+// [dis]assemble the individual tensors that constitute the actual storage used
+// externally into MLIR sparse tensors before calling the original method.
 //
 // In particular, each sparse tensor input
 //
 // void foo(..., t, ...) { }
 //
-// adds the following strucuture in a wrapper
+// makes the original foo() internal and adds the following wrapper method
 //
-// void spiface_foo(..., t1..tn, ...) {
+// void foo(..., t1..tn, ...) {
 //   t = assemble t1..tn
-//   foo(..., t, ...)
+//   _internal_foo(..., t, ...)
 // }
 //
 // and likewise, each output tensor
 //
 // ... T ... bar(...) { return ..., t, ...; }
 //
-// adds the following structure in a wrapper
+// makes the original bar() internal and adds the following wrapper method
 //
-// ... T1..TN ... spiface_bar(..., t1'..tn') {
-//   ..., t, ... = bar(...)
+// ... T1..TN ... bar(..., t1'..tn') {
+//   ..., t, ... = _internal_bar(...)
 //   t1..tn = disassemble t, t1'..tn'
 //   return ..., t1..tn, ...
 // }
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
-    // Only a rewrite an entry with the c-interface requested.
-    if (!funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName()))
+    // Only rewrite public entry methods.
+    if (funcOp.isPrivate())
       return failure();
 
     // Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convTypes(funcOp.getArgumentTypes(), inputTypes);
     convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
 
-    // Only sparse inputs or outputs need a wrapper function.
+    // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
         outputTypes.size() == funcOp.getResultTypes().size())
       return failure();
 
-    // Start the new wrapper function. Together with the c-interface mangling,
-    // a sparse external entry point eventually will have a name like:
-    //    _mlir_ciface_spiface_XXX(...)
+    // Modify the original method into an internal, private method.
+    auto orgName = funcOp.getName();
+    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
+    funcOp.setName(wrapper);
+    funcOp.setPrivate();
+
+    // Start the new public wrapper method with original name.
     Location loc = funcOp.getLoc();
     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
     MLIRContext *context = modOp.getContext();
     OpBuilder moduleBuilder(modOp.getBodyRegion());
-    std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
     unsigned extra = inputTypes.size();
     inputTypes.append(extraTypes);
     auto func = moduleBuilder.create<func::FuncOp>(
-        loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
+        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
     func.setPublic();
-    func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
-                  UnitAttr::get(context));
 
-    // Construct new wrapper function body.
-    auto org = SymbolRefAttr::get(context, funcOp.getName());
+    // Construct new wrapper method body.
     OpBuilder::InsertionGuard insertionGuard(rewriter);
     Block *body = func.addEntryBlock();
     rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original function.
+    // Call original, now internal method.
+    auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                               inputs);
 
@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
              body->getArguments(), outputs, extra, /*isIn=*/false);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
-    // Strip the c-interface attribute from the original function.
-    funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    // Finally, migrate a potential c-interface property.
+    if (funcOp->getAttrOfType<UnitAttr>(
+            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(context));
+      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index 57df8aca3a6a5..c17ba13e86c92 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -3,95 +3,100 @@
 // -----
 
 // CHECK-LABEL: func.func @nop(
-// CHECK-SAME:    %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
 // CHECK:         return %[[A]] : tensor<100xf32>
 // CHECK:       }
-func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
+func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
   return %arg0 : tensor<100xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in(
+// CHECK-LABEL: func.func @sparse_in(
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_in(%[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in2(
+// CHECK-LABEL: func.func @sparse_in2(
 // CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out(
+// CHECK-LABEL: func.func @sparse_out(
 // CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK:         %[[F:.*]] = call @sparse_out(%[[X]])
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %0 : tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out2(
+// CHECK-LABEL: func.func @sparse_out2(
 // CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK:         %[[F:.*]]:2 = call @sparse_out2(%[[X]])
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_inout(
+// CHECK-LABEL: func.func @sparse_inout(
 // CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
 // CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
 // CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
 // CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_inout(%[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_inout
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
new file mode 100644
index 0000000000000..f29e6b143783a
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s --sparse-assembler                 | FileCheck %s --check-prefix=CHECK-HI
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --linalg-generalize-named-ops \
+// RUN:             --linalg-fuse-elementwise-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --sparsifier                       | FileCheck %s --check-prefix=CHECK-LOW
+
+//
+// An example of a module generated by torch-mlir with a sparse tensor from
+// torch.sparse. The MLIR sparsifier should be able to provide the external
+// API through a wrapper method (spiface and ciface). Various passes should
+// compose without trouble.
+//
+
+// CHECK-HI-LABEL: func.func @main
+// CHECK-HI:         sparse_tensor.assemble
+// CHECK-HI:         call @_internal_main
+// CHECK-HI:         return
+// CHECK-HI:       func.func private @_internal_main
+// CHECK-HI:         linalg.matmul
+// CHECK-HI:         return
+//
+// CHECK-MID-LABEL: func.func @main
+// CHECK-MID:          memref.load
+// CHECK-MID:          call @_internal_main
+// CHECK-MID:          return
+// CHECK-MID:       func.func private @_internal_main
+// CHECK-MID:          scf.for
+// CHECK-MID:            scf.for
+// CHECK-MID:          return
+
+// CHECK-LOW-LABEL: llvm.func @main
+// CHECK-LOW:         llvm.call @_internal_main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_mlir_ciface_main
+// CHECK-LOW:         llvm.call @main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_internal_main
+// CHECK-SAME:        {sym_visibility = "private"}
+// CHECK-LOW:         llvm.return
+
+#csc = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+module {
+  func.func @main(%arg0: tensor<64x64xf32, #csc>,
+                  %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = linalg.matmul
+      ins(%arg0, %arg1 : tensor<64x64xf32, #csc>, tensor<64x64xf32>)
+      outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}

@llvmbot
Collaborator

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-mlir

aartbik merged commit d00e6d0 into llvm:main on Feb 5, 2024
3 of 4 checks passed
aartbik deleted the bik branch on February 5, 2024 at 18:48
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
Rewrite *all* public methods, making original internal, private methods,
and exposing wrappers under the original name. This works a bit better
in practice (when combined with c-interface mechanism of torch-mlir for
example).
ichaer added a commit to ichaer/llvm-project-onesided_lower_bound that referenced this pull request Feb 12, 2024
* llvm/main: (328 commits)
  [Flang][OpenMP] Attempt to make map-types-and-sizes.f90 test more agnostic to other architectures
  [Transforms] Add more cos combinations to SimplifyLibCalls and InstCombine (llvm#79699)
  [workflows] Close issues used for backports once the PR has been created (llvm#80394)
  [RISCV] Add support for RISC-V Pointer Masking (llvm#79929)
  [lldb] Cleanup regex in libcxx formatters (NFC) (llvm#80618)
  [lldb] Remove unused private TypeCategoryMap methods (NFC) (llvm#80602)
  [mlir][sparse] refine sparse assembler strategy (llvm#80521)
  [NFC] Fix typo (llvm#80703)
  Fix broken ARM processor features test (llvm#80717)
  [ValueTracking][NFC] Pass `SimplifyQuery` to `computeKnownFPClass` family (llvm#80657)
  [x86_64][windows][swift] do not use Swift async extended frame for wi… (llvm#80468)
  [X86] addConstantComments - add FP16 MOVSH asm comments support
  [X86] Regenerate some vector constant comments missed in recent patches to improve mask predicate handling in addConstantComments
  [clang][AMDGPU][CUDA] Handle __builtin_printf for device printf (llvm#68515)
  Add some clarification to email check message
  [GitHub][Workflows] Prevent multiple private email comments (temporarily) (llvm#80648)
  [workflows] Use /mnt as the build directory on Linux (llvm#80583)
  [Flang][OpenMP] Initial mapping of Fortran pointers and allocatables for target devices (llvm#71766)
  [AMDGPU] GlobalISel for f8 conversions (llvm#80503)
  [AMDGPU] Fixed byte_sel of v_cvt_f32_bf8/v_cvt_f32_fp8 (llvm#80502)
  ...
Labels: mlir:sparse (Sparse compiler in MLIR), mlir
3 participants