[MLIR][Python] Add region_op wrappers for linalg#167616

Merged
ashermancinelli merged 3 commits into llvm:main from ashermancinelli:ajm/linalg-region-op-map-reduce
Nov 12, 2025
Conversation

@ashermancinelli
Contributor

Makes linalg.reduce and linalg.map region_ops so they can be constructed from functions and be called as decorators.

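For readers unfamiliar with the `region_op` helper, here is a minimal, library-free sketch of the decorator pattern this PR enables: decorating a function builds the op, turns the function body into the op's region, and appends a terminator. All names below (`FakeOp`, `FakeYield`) are hypothetical stand-ins, not the real `mlir.dialects` API.

```python
class FakeYield:
    """Stand-in for a terminator op such as linalg.yield."""

    def __init__(self, values):
        self.values = values


class FakeOp:
    """Stand-in for an op with a single region, such as linalg.reduce."""

    def __init__(self, **kwargs):
        self.attrs = kwargs
        self.body = []  # the op's region, as a flat list of "operations"


def region_op(op_cls, terminator):
    """Wrap an op class so it can be used as a decorator factory."""

    def op_builder(**op_kwargs):
        def decorator(fn):
            op = op_cls(**op_kwargs)
            # Calling the user function "builds" the region body; its
            # return value becomes the operand of the terminator.
            result = fn()
            op.body.append(result)
            op.body.append(terminator([result]))
            return op

        return decorator

    return op_builder


# Mirrors `reduce = region_op(ReduceOp, terminator=YieldOp)` from the PR.
reduce = region_op(FakeOp, terminator=FakeYield)


@reduce(dimensions=[0])
def reduced():
    return "mulf(acc, element)"


print(reduced.attrs["dimensions"])  # [0]
print(type(reduced.body[-1]).__name__)  # FakeYield
```

In the real bindings, the decorated function receives the region's block arguments and the returned values are passed to the terminator (here `linalg.yield`), which is why the tests in the diff can write the reduction body as a plain Python function.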
@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Asher Mancinelli (ashermancinelli)

Changes

Makes linalg.reduce and linalg.map region_ops so they can be constructed from functions and be called as decorators.


Full diff: https://github.com/llvm/llvm-project/pull/167616.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/linalg/__init__.py (+3)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+68-1)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index d387c12deeed9..2f463cadfdba9 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -352,3 +352,6 @@ def unpack(
             ip=ip,
         )
     )
+
+reduce = region_op(ReduceOp, terminator=YieldOp)
+map = region_op(MapOp, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 709a1d2424f35..b99e437e34f57 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -1,7 +1,8 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.dialects import arith, func, linalg, tensor, memref
+from mlir.dialects import arith, func, linalg, tensor, memref, builtin
 from mlir.dialects.linalg.opdsl.lang import *
+from mlir.extras import types as T
 from mlir.ir import *
 
 
@@ -857,3 +858,69 @@ def elementwise_op(
                     )
 
         print(module)
+
+@run
+def testReduceOp():
+    with Context(), Location.unknown():
+        f32 = T.f32()
+        tensor_type = T.tensor(10, f32)
+        @builtin.module
+        def module():
+            @func.func(tensor_type)
+            def reduce_op(input):
+                c1 = arith.constant(f32, 1.0)
+                single_result = ir.RankedTensorType.get((), f32)
+                dims = ir.DenseI64ArrayAttr.get([0])
+                init = tensor.splat(single_result, c1, [])
+
+                @linalg.reduce(
+                    result=[single_result],
+                    inputs=[input],
+                    inits=[init],
+                    dimensions=dims,
+                )
+                def reduced(element: f32, acc: f32):
+                    return arith.mulf(acc, element)
+
+                return tensor.extract(reduced, [])
+
+        print(module)
+
+# CHECK-LABEL:   func.func @reduce_op(
+# CHECK-SAME:      %[[ARG0:.*]]: tensor<10xf32>) -> f32 {
+# CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32
+# CHECK:           %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32>
+# CHECK:           %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0]
+# CHECK:           %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32>
+# CHECK:           return %[[EXTRACT_0]] : f32
+# CHECK:         }
+
+@run
+def testMapOp():
+    with Context(), Location.unknown():
+        f32 = T.f32()
+        tensor_type = T.tensor(10, f32)
+        @builtin.module
+        def module():
+            @func.func(tensor_type)
+            def map_op(input):
+                empty = tensor.empty(tensor_type.shape, f32)
+                @linalg.map(
+                    result=[tensor_type],
+                    inputs=[input, input],
+                    init=empty,
+                )
+                def add(element: f32, acc: f32, init: f32):
+                    return arith.addf(element, acc)
+
+                return add
+
+        module.verify()
+        print(module)
+
+# CHECK-LABEL:   func.func @map_op(
+# CHECK-SAME:                      %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> {
+# CHECK:           %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32>
+# CHECK:           %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>)
+# CHECK:           return %[[MAP_0]] : tensor<10xf32>
+# CHECK:         }

@github-actions

github-actions bot commented Nov 12, 2025

✅ With the latest revision this PR passed the Python code formatter.

@makslevental (Contributor) left a comment

LGTM!

@ashermancinelli ashermancinelli merged commit 175e3be into llvm:main Nov 12, 2025
8 of 9 checks passed
3 participants