[mlir][sparse] refine sparse assembler strategy #80521

Merged 2 commits on Feb 5, 2024

Changes from 1 commit
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   let summary = "Add [dis]assemble operations on external sparse tensors";
   let description = [{
     A pass that converts public entry methods that use sparse tensors as
-    input parameters and/or output return values into wrapper functions
+    input parameters and/or output return values into wrapper methods
     that [dis]assemble the individual tensors that constitute the actual
     storage used externally into MLIR sparse tensors. This pass can be used
     to prepare the public entry methods of a program that is compiled by the
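Editor's note: a minimal sketch of the refined strategy for a sparse input (illustrative names @foo, %v, %p, %c; assemble type clauses elided), mirroring the sparse_in test further below:

// Before the pass: a public entry method takes an MLIR sparse tensor.
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
func.func @foo(%t: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
  // ... original body ...
}

// After the pass: the original method is renamed and made private, and a
// public wrapper with the original name first assembles the constituent
// buffers (values, positions, coordinates) into the sparse tensor.
func.func @foo(%v: tensor<?xf32>, %p: tensor<?xindex>, %c: tensor<?xindex>) -> tensor<64x64xf32> {
  %t = sparse_tensor.assemble %v, %p, %c ...  // into tensor<64x64xf32, #sparse>
  %r = call @_internal_foo(%t) : (tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
  return %r : tensor<64x64xf32>
}
func.func private @_internal_foo(%t: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
  // ... original body ...
}

External callers thus keep calling @foo under its original name, which removes the need for the mangled spiface_ prefix that the old strategy introduced.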
55 changes: 30 additions & 25 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
 namespace {
 
 // A rewriting rules that converts public entry methods that use sparse tensors
-// as input parameters and/or output return values into wrapper functions
-// that [dis]assemble the individual tensors that constitute the actual
-// storage used externally into MLIR sparse tensors.
+// as input parameters and/or output return values into wrapper methods that
+// [dis]assemble the individual tensors that constitute the actual storage used
+// externally into MLIR sparse tensors before calling the original method.
 //
 // In particular, each sparse tensor input
 //
 // void foo(..., t, ...) { }
 //
-// adds the following strucuture in a wrapper
+// makes the original foo() internal and adds the following wrapper method
 //
-// void spiface_foo(..., t1..tn, ...) {
+// void foo(..., t1..tn, ...) {
 //   t = assemble t1..tn
-//   foo(..., t, ...)
+//   _internal_foo(..., t, ...)
 // }
 //
 // and likewise, each output tensor
 //
 // ... T ... bar(...) { return ..., t, ...; }
 //
-// adds the following structure in a wrapper
+// makes the original bar() internal and adds the following wrapper method
 //
-// ... T1..TN ... spiface_bar(..., t1'..tn') {
-//   ..., t, ... = bar(...)
+// ... T1..TN ... bar(..., t1'..tn') {
+//   ..., t, ... = _internal_bar(...)
 //   t1..tn = disassemble t, t1'..tn'
 //   return ..., t1..tn, ...
 // }
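Editor's note: a hedged MLIR rendering of the bar() pseudocode above (illustrative names; disassemble operands and type clauses elided), matching the sparse_out tests below. The wrapper receives extra buffers t1'..tn' that provide storage for the disassembled sparse result:

func.func @bar(%x: tensor<64x64xf32>,
               %v: tensor<?xf32>, %p: tensor<?xindex>, %c: tensor<?xindex>)
    -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
  %t = call @_internal_bar(%x) : (tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>
  %v2, %p2, %c2 = sparse_tensor.disassemble %t ...  // into %v, %p, %c
  return %v2, %p2, %c2 : tensor<?xf32>, tensor<?xindex>, tensor<?xindex>
}
func.func private @_internal_bar(%x: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
  // ... original body ...
}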
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {

   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
-    // Only a rewrite an entry with the c-interface requested.
-    if (!funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName()))
+    // Only rewrite public entry methods.
+    if (funcOp.isPrivate())
       return failure();
 
     // Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convTypes(funcOp.getArgumentTypes(), inputTypes);
     convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
 
-    // Only sparse inputs or outputs need a wrapper function.
+    // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
         outputTypes.size() == funcOp.getResultTypes().size())
       return failure();
 
-    // Start the new wrapper function. Together with the c-interface mangling,
-    // a sparse external entry point eventually will have a name like:
-    //   _mlir_ciface_spiface_XXX(...)
+    // Modify the original method into an internal, private method.
+    auto orgName = funcOp.getName();
+    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
+    funcOp.setName(wrapper);
+    funcOp.setPrivate();
+
+    // Start the new public wrapper method with original name.
     Location loc = funcOp.getLoc();
     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
     MLIRContext *context = modOp.getContext();
     OpBuilder moduleBuilder(modOp.getBodyRegion());
-    std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
     unsigned extra = inputTypes.size();
     inputTypes.append(extraTypes);
     auto func = moduleBuilder.create<func::FuncOp>(
-        loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
+        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
     func.setPublic();
-    func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
-                  UnitAttr::get(context));
 
-    // Construct new wrapper function body.
-    auto org = SymbolRefAttr::get(context, funcOp.getName());
+    // Construct new wrapper method body.
     OpBuilder::InsertionGuard insertionGuard(rewriter);
     Block *body = func.addEntryBlock();
     rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original function.
+    // Call original, now internal method.
+    auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                               inputs);

@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
              body->getArguments(), outputs, extra, /*isIn=*/false);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
-    // Strip the c-interface attribute from the original function.
-    funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    // Finally, migrate a potential c-interface property.
+    if (funcOp->getAttrOfType<UnitAttr>(
+            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(context));
+      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    }
     return success();
   }
 };
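Editor's note: the last hunk migrates a requested c-interface to the wrapper instead of stripping it. A sketch of the net effect under the new naming (grounded in the torch_linalg.mlir test below; signatures elided):

// Before: the public entry method requests a C wrapper.
func.func @main(...) attributes { llvm.emit_c_interface } { ... }

// After: the llvm.emit_c_interface attribute travels to the public wrapper,
// so LLVM lowering emits _mlir_ciface_main (which calls @main), while the
// internal method stays a plain private function.
func.func @main(...) attributes { llvm.emit_c_interface } {
  // ... assemble, call @_internal_main, disassemble ...
}
func.func private @_internal_main(...) {
  // ... original body ...
}

Previously the C entry point would have carried the doubly mangled name _mlir_ciface_spiface_main; with this change it keeps the natural _mlir_ciface_main.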
49 changes: 27 additions & 22 deletions mlir/test/Dialect/SparseTensor/external.mlir
@@ -3,95 +3,100 @@
 // -----
 
 // CHECK-LABEL: func.func @nop(
-// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
 // CHECK: return %[[A]] : tensor<100xf32>
 // CHECK: }
-func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
+func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
   return %arg0 : tensor<100xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in(
+// CHECK-LABEL: func.func @sparse_in(
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_in
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in2(
+// CHECK-LABEL: func.func @sparse_in2(
 // CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_in2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out(
+// CHECK-LABEL: func.func @sparse_out(
 // CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]] = call @sparse_out(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_out
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %0 : tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out2(
+// CHECK-LABEL: func.func @sparse_out2(
 // CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]]:2 = call @sparse_out2(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]#1
 // CHECK: return %[[F]]#0
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_out2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_inout(
+// CHECK-LABEL: func.func @sparse_inout(
 // CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
 // CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
 // CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
 // CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_inout(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_inout
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
55 changes: 55 additions & 0 deletions mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -0,0 +1,55 @@
// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
// RUN: mlir-opt %s --sparse-assembler \
// RUN: --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
// RUN: mlir-opt %s --sparse-assembler \
// RUN: --sparsifier | FileCheck %s --check-prefix=CHECK-LOW

//
// An example of a module generated by torch-mlir with a sparse tensor from
// torch.sparse. The MLIR sparsifier should be able to provide the external
// API through a wrapper method (spiface and ciface). Various passes should
// compose without trouble.
//

// CHECK-HI-LABEL: func.func @main
// CHECK-HI: sparse_tensor.assemble
// CHECK-HI: call @_internal_main
// CHECK-HI: return
// CHECK-HI: func.func private @_internal_main
// CHECK-HI: linalg.matmul
// CHECK-HI: return
//
// CHECK-MID-LABEL: func.func @main
// CHECK-MID: memref.load
// CHECK-MID: call @_internal_main
// CHECK-MID: return
// CHECK-MID: func.func private @_internal_main
// CHECK-MID: scf.for
// CHECK-MID: scf.for
// CHECK-MID: return

// CHECK-LOW-LABEL: llvm.func @main
// CHECK-LOW: llvm.call @_internal_main
// CHECK-LOW: llvm.return
// CHECK-LOW: llvm.func @_mlir_ciface_main
// CHECK-LOW: llvm.call @main
// CHECK-LOW: llvm.return
// CHECK-LOW: llvm.func @_internal_main
// CHECK-SAME: {sym_visibility = "private"}
// CHECK-LOW: llvm.return

#csc = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
func.func @main(%arg0: tensor<64x64xf32, #csc>,
%arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<64x64xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
%2 = linalg.matmul
ins(%arg0, %arg1 : tensor<64x64xf32, #csc>, tensor<64x64xf32>)
outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
return %2 : tensor<64x64xf32>
}
}