Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 50 additions & 51 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file defines Pack + Unpack Ops that have been moved from the Tensor
// dialect. As such, these are defined as memory-effect-free and only accept
// "tensors" as inputs.
//
// TODO: Once a good motivating example is identified, relax these
// restrictions.
// dialect.
//
//===----------------------------------------------------------------------===//

Expand All @@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
// RelayoutOp
//===----------------------------------------------------------------------===//

class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
: Op<Linalg_Dialect, mnemonic,
!listconcat(
traits, [DeclareOpInterfaceMethods<
OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<
ReifyRankedShapedTypeOpInterface, [
"reifyResultShapes"]>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">])> {
OptionalTypesMatchWith<"result type matches type of dest",
"dest", "result", "$_self">])> {

code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
ShapedType getSourceType() {
return ::llvm::cast<ShapedType>(getSource().getType()); };
ShapedType getDestType() {
return ::llvm::cast<ShapedType>(getDest().getType()); };

MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

Expand Down Expand Up @@ -195,23 +194,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
TensorOrMemRef<[AnyType]>:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
Copy link
Contributor

@rolfmorel rolfmorel Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think changing the op's formal results here is making the generated Python PackOp -- used here to expose a nice linalg.pack Python wrapper -- take the result type as the first argument. You will just need a little bit of logic in this function (and presumably in the linalg.unpack Python wrapper as well) that determines if the op is using tensor semantics, in which case the call to the PackOp constructor should look something like PackOp(dest.type, ..., and, in the case of memref types/semantics, just PackOp(None, .... (N.B. syntax might be a bit different.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sakupan102 on a related note, could you also add a memref pack/unpack test case to the python tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test case in ops.py


let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand All @@ -233,7 +223,21 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedType(RankedTensorType sourceType,
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Returns the shape of the packed type. It is a shared helper that helps
// type inference methods in a way that ensures that they agree on which
// dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

Expand Down Expand Up @@ -285,6 +289,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
let hasCanonicalizeMethod = 1;

let hasFolder = 1;

let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -352,21 +358,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
// Outer Dims: 9x3x8 Inner Dims: 4x2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
TensorOrMemRef<[AnyType]>:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand Down Expand Up @@ -409,6 +406,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
let hasCanonicalizeMethod = 1;

let hasFolder = 1;

let hasCustomAssemblyFormat = 1;
}

#endif // LINALG_RELEAYOUT_OPS
Loading