-
Notifications
You must be signed in to change notification settings - Fork 15.9k
[mlir][Linalg] implement bufferization for linalg.pack
#177982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add a BufferizableOpInterface implementation for linalg.pack now that pack supports memref semantics llvm@4b066c7. This completes the op’s bufferization path and avoids copy-before-write for destination operands. Signed-off-by: Ryutaro Okada <[email protected]>
|
@llvm/pr-subscribers-mlir-linalg Author: Ryutaro Okada (sakupan102) ChangesAdd a BufferizableOpInterface implementation for linalg.pack now that pack supports memref semantics 4b066c7. This completes the op’s bufferization path and avoids copy-before-write for destination operands. Full diff: https://github.com/llvm/llvm-project/pull/177982.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..60c685578682a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,47 @@ struct SoftmaxOpInterface
return success();
}
};
+
+struct PackOpInterface
+ : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+ linalg::PackOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto packOp = cast<linalg::PackOp>(op);
+ return !packOp.isDpsInit(&opOperand);
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ auto packOp = cast<linalg::PackOp>(op);
+ if (packOp.hasPureBufferSemantics())
+ return success();
+ if (!packOp.hasPureTensorSemantics())
+ return packOp.emitError() << "op does not have pure tensor semantics";
+
+ FailureOr<Value> sourceBuffer =
+ getBuffer(rewriter, packOp.getSource(), options, state);
+ if (failed(sourceBuffer))
+ return failure();
+ FailureOr<Value> destBuffer =
+ getBuffer(rewriter, packOp.getDest(), options, state);
+ if (failed(destBuffer))
+ return failure();
+
+ SmallVector<Value> operands;
+ operands.push_back(*sourceBuffer);
+ operands.push_back(*destBuffer);
+ if (auto val = packOp.getPaddingValue())
+ operands.push_back(val);
+ llvm::append_range(operands, packOp.getInnerTiles());
+
+ linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange{}, operands,
+ op->getAttrs());
+ replaceOpWithBufferizedValues(rewriter, op, *destBuffer);
+ return success();
+ }
+};
} // namespace
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +247,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
>::registerOpInterface(ctx);
SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+ PackOp::attachInterface<PackOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..2cb09c39b5776 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,20 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-NOT: memref.copy
+// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
+// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+ %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ return %0 : tensor<16x8x8x32xf32>
+}
+
|
|
@llvm/pr-subscribers-mlir Author: Ryutaro Okada (sakupan102) ChangesAdd a BufferizableOpInterface implementation for linalg.pack now that pack supports memref semantics 4b066c7. This completes the op’s bufferization path and avoids copy-before-write for destination operands. Full diff: https://github.com/llvm/llvm-project/pull/177982.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..60c685578682a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,47 @@ struct SoftmaxOpInterface
return success();
}
};
+
+struct PackOpInterface
+ : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+ linalg::PackOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto packOp = cast<linalg::PackOp>(op);
+ return !packOp.isDpsInit(&opOperand);
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ auto packOp = cast<linalg::PackOp>(op);
+ if (packOp.hasPureBufferSemantics())
+ return success();
+ if (!packOp.hasPureTensorSemantics())
+ return packOp.emitError() << "op does not have pure tensor semantics";
+
+ FailureOr<Value> sourceBuffer =
+ getBuffer(rewriter, packOp.getSource(), options, state);
+ if (failed(sourceBuffer))
+ return failure();
+ FailureOr<Value> destBuffer =
+ getBuffer(rewriter, packOp.getDest(), options, state);
+ if (failed(destBuffer))
+ return failure();
+
+ SmallVector<Value> operands;
+ operands.push_back(*sourceBuffer);
+ operands.push_back(*destBuffer);
+ if (auto val = packOp.getPaddingValue())
+ operands.push_back(val);
+ llvm::append_range(operands, packOp.getInnerTiles());
+
+ linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange{}, operands,
+ op->getAttrs());
+ replaceOpWithBufferizedValues(rewriter, op, *destBuffer);
+ return success();
+ }
+};
} // namespace
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +247,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
>::registerOpInterface(ctx);
SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+ PackOp::attachInterface<PackOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..2cb09c39b5776 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,20 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-NOT: memref.copy
+// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
+// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+ %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ return %0 : tensor<16x8x8x32xf32>
+}
+
|
|
If everything looks good, could you merge it? |
Signed-off-by: Ryutaro Okada <[email protected]>
340a435 to
f9ea1db
Compare
Sure but let’s give others a chance to have a look too. |
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Ryutaro Okada <[email protected]>
Add a BufferizableOpInterface implementation for linalg.pack now that pack supports memref semantics llvm@4b066c7. This completes the op’s bufferization path and avoids copy-before-write for destination operands. --------- Signed-off-by: Ryutaro Okada <[email protected]>
Reverts: - Carries local revert of llvm/llvm-project#169614 due to #22649. - Adds revert of llvm/llvm-project#177982. `reifyResultShapes()` is unimplemented for pack ops on memrefs causing a crash in [getPackUnPackIterationDomain](https://github.com/iree-org/llvm-project/blob/b24bd7161ca7eb6e9652c34b92e200ac16af3628/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp#L718) Fixes: - Passes `LLVM::DIFlags::Zero` to address the added argument for L`LVM::DIDerivedTypeAttr::get()` [llvm/llvm-project#177889](llvm/llvm-project#177889) - Removes he `firstIndex` parameter to address the API change to `visitNonControlFlowArguments()` [llvm/llvm-project#175210](llvm/llvm-project#175210) https://github.com/iree-org/llvm-project/tree/sm-iree-integrates/llvm-20260128 Signed-off-by: Ian Wood <[email protected]>
Add a BufferizableOpInterface implementation for linalg.pack now that pack supports memref semantics 4b066c7. This completes the op’s bufferization path and avoids copy-before-write for destination operands.