diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index d387c12deeed9..c92bda74c12bf 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -352,3 +352,7 @@ def unpack( ip=ip, ) ) + + +reduce = region_op(ReduceOp, terminator=YieldOp) +map = region_op(MapOp, terminator=YieldOp) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 709a1d2424f35..92591cd59fb40 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -1,7 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.dialects import arith, func, linalg, tensor, memref +from mlir.dialects import arith, func, linalg, tensor, memref, builtin from mlir.dialects.linalg.opdsl.lang import * +from mlir.extras import types as T from mlir.ir import * @@ -857,3 +858,76 @@ def elementwise_op( ) print(module) + + +@run +def testReduceOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) + def reduce_op(input): + c1 = arith.constant(f32, 1.0) + single_result = ir.RankedTensorType.get((), f32) + dims = ir.DenseI64ArrayAttr.get([0]) + init = tensor.splat(single_result, c1, []) + + @linalg.reduce( + result=[single_result], + inputs=[input], + inits=[init], + dimensions=dims, + ) + def reduced(element: f32, acc: f32): + return arith.mulf(acc, element) + + return tensor.extract(reduced, []) + + print(module) + + +# CHECK-LABEL: func.func @reduce_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32 +# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor +# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor) dimensions = [0] +# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor +# CHECK: return %[[EXTRACT_0]] : f32 +# CHECK: } + + +@run +def testMapOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) + def map_op(input): + empty = tensor.empty(tensor_type.shape, f32) + + @linalg.map( + result=[tensor_type], + inputs=[input, input], + init=empty, + ) + def add(element: f32, acc: f32, init: f32): + return arith.addf(element, acc) + + return add + + module.verify() + print(module) + + +# CHECK-LABEL: func.func @map_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> { +# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32> +# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>) +# CHECK: return %[[MAP_0]] : tensor<10xf32> +# CHECK: }