diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..aed4fbf12bd43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,46 @@ struct SoftmaxOpInterface
     return success();
   }
 };
+
+struct PackOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+                                                     linalg::PackOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    return !packOp.isDpsInit(&opOperand);
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    assert(!packOp.hasPureBufferSemantics() && "expected op with tensors");
+    if (!packOp.hasPureTensorSemantics())
+      return packOp.emitError()
+             << "mixed tensor/buffer semantic op not supported yet";
+    FailureOr<Value> sourceBuffer =
+        getBuffer(rewriter, packOp.getSource(), options, state);
+    if (failed(sourceBuffer))
+      return failure();
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, packOp.getDest(), options, state);
+    if (failed(destBuffer))
+      return failure();
+
+    SmallVector<Value> operands;
+    operands.push_back(*sourceBuffer);
+    operands.push_back(*destBuffer);
+    if (auto val = packOp.getPaddingValue())
+      operands.push_back(val);
+    llvm::append_range(operands, packOp.getInnerTiles());
+
+    linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange{}, operands,
+                           op->getAttrs());
+    replaceOpWithBufferizedValues(rewriter, op, *destBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +206,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
         >::registerOpInterface(ctx);
 
     SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+    PackOp::attachInterface<PackOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..6729cc4b76c4d 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,23 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
     outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME:    %[[SRC:.*]]: tensor<200x127x256xf32>, %[[DST:.*]]: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<200x127x256xf32> to memref<200x127x256xf32>
+// CHECK-DAG:     %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<256x64x200x2xf32>
+// CHECK-NOT:     memref.copy
+// CHECK:         linalg.pack %[[SRC_BUF]] padding_value(%[[CST]] : f32) outer_dims_perm = [2, 1, 0] inner_dims_pos = [1] inner_tiles = [2] into %[[DST_BUF]] : memref<200x127x256xf32> -> memref<256x64x200x2xf32>
+// CHECK:         %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<256x64x200x2xf32> to tensor<256x64x200x2xf32>
+// CHECK:         return %[[RESULT]] : tensor<256x64x200x2xf32>
+func.func @bufferize_pack(%arg0: tensor<200x127x256xf32>, %arg1: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+  %pad = arith.constant 0.0 : f32
+  %0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
+    inner_dims_pos = [1] inner_tiles = [2] into %arg1
+    : tensor<200x127x256xf32>
+    -> tensor<256x64x200x2xf32>
+  return %0 : tensor<256x64x200x2xf32>
+}