-
Notifications
You must be signed in to change notification settings - Fork 15.9k
[mlir][linalg] Extend linalg.pack and linalg.unpack to accept memref #167675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0d667f5
4b6dbbf
cf28695
e00de18
6ec4b5d
a4ba0fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,11 +7,7 @@ | |
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file defines Pack + Unpack Ops that have been moved from the Tensor | ||
| // dialect. As such, these are defined as memory-effect-free and only accept | ||
| // "tensors" as inputs. | ||
| // | ||
| // TODO: Once a good motivating example is identified, relax these | ||
| // restrictions. | ||
| // dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
|
|
@@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td" | |
| // RelayoutOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> : | ||
| Op<Linalg_Dialect, mnemonic, !listconcat(traits, [ | ||
| DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, | ||
| DestinationStyleOpInterface, LinalgRelayoutOpInterface, | ||
| ConditionallySpeculatable, NoMemoryEffect, | ||
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ | ||
| class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> | ||
| : Op<Linalg_Dialect, mnemonic, | ||
| !listconcat( | ||
| traits, [DeclareOpInterfaceMethods< | ||
| OpAsmOpInterface, ["getAsmResultNames"]>, | ||
| DestinationStyleOpInterface, LinalgRelayoutOpInterface, | ||
| ConditionallySpeculatable, | ||
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, | ||
| DeclareOpInterfaceMethods< | ||
| ReifyRankedShapedTypeOpInterface, [ | ||
| "reifyResultShapes"]>, | ||
| TypesMatchWith<"result type matches type of dest", | ||
| "dest", "result", | ||
| "$_self">])> { | ||
| OptionalTypesMatchWith<"result type matches type of dest", | ||
| "dest", "result", "$_self">])> { | ||
|
|
||
| code commonExtraClassDeclaration = [{ | ||
| size_t getSourceRank() { return getSourceType().getRank(); }; | ||
| size_t getDestRank() { return getDestType().getRank(); }; | ||
| RankedTensorType getSourceType() { | ||
| return ::llvm::cast<RankedTensorType>(getSource().getType()); }; | ||
| RankedTensorType getDestType() { | ||
| return ::llvm::cast<RankedTensorType>(getDest().getType()); }; | ||
| ShapedType getSourceType() { | ||
| return ::llvm::cast<ShapedType>(getSource().getType()); }; | ||
| ShapedType getDestType() { | ||
| return ::llvm::cast<ShapedType>(getDest().getType()); }; | ||
|
|
||
| MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } | ||
|
|
||
|
|
@@ -195,23 +194,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ | |
| // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2 | ||
| ``` | ||
| }]; | ||
| let arguments = (ins AnyRankedTensor:$source, | ||
| AnyRankedTensor:$dest, | ||
| Optional<AnyType>:$padding_value, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, | ||
| DenseI64ArrayAttr:$inner_dims_pos, | ||
| Variadic<Index>:$inner_tiles, | ||
| DenseI64ArrayAttr:$static_inner_tiles); | ||
| let results = (outs AnyRankedTensor:$result); | ||
| let assemblyFormat = [{ | ||
| $source | ||
| (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? | ||
| (`outer_dims_perm` `=` $outer_dims_perm^)? | ||
| `inner_dims_pos` `=` $inner_dims_pos | ||
| `inner_tiles` `=` | ||
| custom<DynamicIndexList>($inner_tiles, $static_inner_tiles) | ||
| `into` $dest attr-dict `:` type($source) `->` type($dest) | ||
| }]; | ||
| let arguments = (ins TensorOrMemRef<[AnyType]>:$source, | ||
| TensorOrMemRef<[AnyType]>:$dest, | ||
| Optional<AnyType>:$padding_value, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, | ||
| DenseI64ArrayAttr:$inner_dims_pos, | ||
| Variadic<Index>:$inner_tiles, | ||
| DenseI64ArrayAttr:$static_inner_tiles); | ||
| let results = (outs Optional<AnyRankedTensor>:$result); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think changing the op's formal results here is making the generated Python
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sakupan102 on a related note, could you also add a memref pack/unpack test case to the python tests?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added test case in |
||
|
|
||
| let builders = [ | ||
| OpBuilder<(ins "Value":$source, "Value":$dest, | ||
|
|
@@ -233,7 +223,21 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ | |
| // Method to get the `RankedTensorType` of the result based on the inner | ||
| // tiles, position of the inner tiles (innerDimsPos) and interchange vector | ||
| // of outer loops (outerDimsPerm). | ||
| static RankedTensorType inferPackedType(RankedTensorType sourceType, | ||
| static RankedTensorType inferPackedTensorType(RankedTensorType sourceType, | ||
| ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos, | ||
| ArrayRef<int64_t> outerDimsPerm = {}); | ||
|
|
||
| // Method to get the `MemRefType` of the result based on the inner | ||
| // tiles, position of the inner tiles (innerDimsPos) and interchange vector | ||
| // of outer loops (outerDimsPerm). | ||
| static MemRefType inferPackedMemRefType(MemRefType sourceType, | ||
| ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos, | ||
| ArrayRef<int64_t> outerDimsPerm = {}); | ||
|
|
||
| // Returns the shape of the packed type. It is a shared helper that helps | ||
| // type inference methods in a way that ensures that they agree on which | ||
| // dimensions are dynamic. | ||
| static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape, | ||
| ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos, | ||
| ArrayRef<int64_t> outerDimsPerm = {}); | ||
|
|
||
|
|
@@ -285,6 +289,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ | |
| let hasCanonicalizeMethod = 1; | ||
|
|
||
| let hasFolder = 1; | ||
|
|
||
| let hasCustomAssemblyFormat = 1; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -352,21 +358,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { | |
| // Outer Dims: 9x3x8 Inner Dims: 4x2 | ||
| ``` | ||
| }]; | ||
| let arguments = (ins AnyRankedTensor:$source, | ||
| AnyRankedTensor:$dest, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, | ||
| DenseI64ArrayAttr:$inner_dims_pos, | ||
| Variadic<Index>:$inner_tiles, | ||
| DenseI64ArrayAttr:$static_inner_tiles); | ||
| let results = (outs AnyRankedTensor:$result); | ||
| let assemblyFormat = [{ | ||
| $source | ||
| (`outer_dims_perm` `=` $outer_dims_perm^)? | ||
| `inner_dims_pos` `=` $inner_dims_pos | ||
| `inner_tiles` `=` | ||
| custom<DynamicIndexList>($inner_tiles, $static_inner_tiles) | ||
| `into` $dest attr-dict `:` type($source) `->` type($dest) | ||
| }]; | ||
| let arguments = (ins TensorOrMemRef<[AnyType]>:$source, | ||
| TensorOrMemRef<[AnyType]>:$dest, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, | ||
| DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles, | ||
| DenseI64ArrayAttr:$static_inner_tiles); | ||
| let results = (outs Optional<AnyRankedTensor>:$result); | ||
|
|
||
| let builders = [ | ||
| OpBuilder<(ins "Value":$source, "Value":$dest, | ||
|
|
@@ -409,6 +406,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { | |
| let hasCanonicalizeMethod = 1; | ||
|
|
||
| let hasFolder = 1; | ||
|
|
||
| let hasCustomAssemblyFormat = 1; | ||
| } | ||
|
|
||
| #endif // LINALG_RELEAYOUT_OPS | ||
Uh oh!
There was an error while loading. Please reload this page.