From 7bfe3ae09cfeadf8925094432554cab86c032e93 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88@gmail.com>
Date: Tue, 27 Jan 2026 00:46:44 +0900
Subject: [PATCH 1/3] [mlir][Linalg] implement bufferization for `linalg.pack`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a BufferizableOpInterface implementation for linalg.pack now that
pack supports memref semantics
https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9.
This completes the op’s bufferization path and avoids copy-before-write
for destination operands.

Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com>
---
 .../BufferizableOpInterfaceImpl.cpp           | 42 +++++++++++++++++++
 mlir/test/Dialect/Linalg/bufferize.mlir       | 17 ++++++++
 2 files changed, 59 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..60c685578682a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,47 @@ struct SoftmaxOpInterface
     return success();
   }
 };
+
+struct PackOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+                                                     linalg::PackOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    return !packOp.isDpsInit(&opOperand);
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    if (packOp.hasPureBufferSemantics())
+      return success();
+    if (!packOp.hasPureTensorSemantics())
+      return packOp.emitError() << "op does not have pure tensor semantics";
+
+    FailureOr<Value> sourceBuffer =
+        getBuffer(rewriter, packOp.getSource(), options, state);
+    if (failed(sourceBuffer))
+      return failure();
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, packOp.getDest(), options, state);
+    if (failed(destBuffer))
+      return failure();
+
+    SmallVector<Value> operands;
+    operands.push_back(*sourceBuffer);
+    operands.push_back(*destBuffer);
+    if (auto val = packOp.getPaddingValue())
+      operands.push_back(val);
+    llvm::append_range(operands, packOp.getInnerTiles());
+
+    linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange{}, operands,
+                           op->getAttrs());
+    replaceOpWithBufferizedValues(rewriter, op, *destBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +247,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
         >::registerOpInterface(ctx);
 
     SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+    PackOp::attachInterface<PackOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..2cb09c39b5776 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,20 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
     outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-NOT: memref.copy
+// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
+// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+    into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  return %0 : tensor<16x8x8x32xf32>
+}
+

From f9ea1dbbc1f1f3dcfcf76841e9bbcaf89e3a9626 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88@gmail.com>
Date: Tue, 27 Jan 2026 01:45:37 +0900
Subject: [PATCH 2/3] expand test to include padding_value and outer_dims_perm

Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com>
---
 mlir/test/Dialect/Linalg/bufferize.mlir | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 2cb09c39b5776..6729cc4b76c4d 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -210,16 +210,19 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
 // -----
 
 // CHECK-LABEL: func @bufferize_pack(
-// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
-// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
-// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-SAME: %[[SRC:.*]]: tensor<200x127x256xf32>, %[[DST:.*]]: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<200x127x256xf32> to memref<200x127x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<256x64x200x2xf32>
 // CHECK-NOT: memref.copy
-// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
-// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
-func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
-  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-    into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
-  return %0 : tensor<16x8x8x32xf32>
+// CHECK: linalg.pack %[[SRC_BUF]] padding_value(%[[CST]] : f32) outer_dims_perm = [2, 1, 0] inner_dims_pos = [1] inner_tiles = [2] into %[[DST_BUF]] : memref<200x127x256xf32> -> memref<256x64x200x2xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<256x64x200x2xf32> to tensor<256x64x200x2xf32>
+// CHECK: return %[[RESULT]] : tensor<256x64x200x2xf32>
+func.func @bufferize_pack(%arg0: tensor<200x127x256xf32>, %arg1: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+  %pad = arith.constant 0.0 : f32
+  %0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
+    inner_dims_pos = [1] inner_tiles = [2] into %arg1
+    : tensor<200x127x256xf32> -> tensor<256x64x200x2xf32>
+  return %0 : tensor<256x64x200x2xf32>
 }
 

From 137b07e112c703e90796a54c7f62017d4a23b86d Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88@gmail.com>
Date: Tue, 27 Jan 2026 22:06:55 +0900
Subject: [PATCH 3/3] fix upon review

Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com>
---
 .../Linalg/Transforms/BufferizableOpInterfaceImpl.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 60c685578682a..aed4fbf12bd43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -205,11 +205,10 @@ struct PackOpInterface
                           const BufferizationOptions &options,
                           BufferizationState &state) const {
     auto packOp = cast<linalg::PackOp>(op);
-    if (packOp.hasPureBufferSemantics())
-      return success();
+    assert(!packOp.hasPureBufferSemantics() && "expected op with tensors");
     if (!packOp.hasPureTensorSemantics())
-      return packOp.emitError() << "op does not have pure tensor semantics";
-
+      return packOp.emitError()
+             << "mixed tensor/buffer semantic op not supported yet";
     FailureOr<Value> sourceBuffer =
         getBuffer(rewriter, packOp.getSource(), options, state);
     if (failed(sourceBuffer))