[mlir][linalg][python] Add Python Bindings for Inferring Contraction Dimensions from Affine Maps#167587
Merged
bangtianliu merged 4 commits intollvm:mainfrom Nov 12, 2025
Merged
Conversation
Member
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Bangtian Liu (bangtianliu) ChangesThis patch exposes Full diff: https://github.com/llvm/llvm-project/pull/167587.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 339e63d667c5e..989835485bdd9 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -10,6 +10,7 @@
#ifndef MLIR_C_DIALECT_LINALG_H
#define MLIR_C_DIALECT_LINALG_H
+#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+ intptr_t numMaps);
+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
typedef struct MlirLinalgConvolutionDimensions {
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 015502371c65b..c179dac3c4df1 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -80,6 +80,29 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"op.",
nb::arg("op"));
+ m.def(
+ "infer_contraction_dimensions_from_maps",
+ [](std::vector<MlirAffineMap> indexingMaps)
+ -> std::optional<MlirLinalgContractionDimensions> {
+ if (indexingMaps.empty())
+ return std::nullopt;
+
+ MlirLinalgContractionDimensions dims =
+ mlirLinalgInferContractionDimensionsFromMaps(
+ indexingMaps.data(), indexingMaps.size());
+
+ // Detect "empty" result. This occurs when the input is invalid
+ // or when `linalg::inferContractionDims` fails.
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+ mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+ return std::nullopt;
+ }
+ return dims;
+ },
+ "Infers contraction dimensions (batch/m/n/k) from a list of affine "
+ "maps.",
+ nb::arg("indexing_maps"));
+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 5c2a65d2c4c8a..ef2658167bb46 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Linalg.h"
+#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -75,6 +76,40 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+ intptr_t numMaps) {
+ MlirLinalgContractionDimensions result{};
+ if (!indexingMaps || numMaps <= 0)
+ return result;
+
+ SmallVector<AffineMap> maps;
+ maps.reserve(numMaps);
+ for (intptr_t i = 0; i < numMaps; ++i) {
+ maps.push_back(unwrap(indexingMaps[i]));
+ }
+
+ FailureOr<linalg::ContractionDimensions> maybeDims =
+ linalg::inferContractionDims(maps);
+ if (failed(maybeDims))
+ return result;
+
+ const linalg::ContractionDimensions &contractionDims = *maybeDims;
+ MLIRContext *ctx = maps[0].getContext();
+
+ auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ return wrap(
+ DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+ };
+
+ result.batch = toAttr(contractionDims.batch);
+ result.m = toAttr(contractionDims.m);
+ result.n = toAttr(contractionDims.n);
+ result.k = toAttr(contractionDims.k);
+
+ return result;
+}
+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 5f7cb6a6c83cb..c0fd0efed7c65 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -208,3 +208,45 @@ def matmul_func(a, b, c):
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map
+
+
+@run
+def test_infer_contraction_dimensions_from_maps():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # === Test valid contraction (matmul) ===
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ dims = linalg.infer_contraction_dimensions_from_maps(
+ [a_map, b_map, c_map]
+ )
+ assert dims is not None
+
+ # Expect m=[0], n=[1], k=[2] as per standard matmul.
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"
+
+ # === Test invalid input (wrong number of maps) ===
+ invalid_dims = linalg.infer_contraction_dimensions_from_maps(
+ [a_map, b_map]
+ )
+ assert invalid_dims is None
+
+ # === Test non-contraction (element-wise operation) ===
+ dim_i = AffineDimExpr.get(0)
+ dim_j = AffineDimExpr.get(1)
+ elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
+ non_contraction_dims = linalg.infer_contraction_dimensions_from_maps(
+ [elementwise_map, elementwise_map, elementwise_map]
+ )
+ assert non_contraction_dims is None
|
5ba54b2 to
1a520ea
Compare
|
✅ With the latest revision this PR passed the Python code formatter. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
9984adf to
f875428
Compare
6bc3096 to
a8d064e
Compare
f1d9e3a to
b76fce9
Compare
kuhar
reviewed
Nov 12, 2025
kuhar
reviewed
Nov 12, 2025
0d047ad to
e9f7979
Compare
kuhar
reviewed
Nov 12, 2025
kuhar
approved these changes
Nov 12, 2025
kuhar
reviewed
Nov 12, 2025
39a3d04 to
21cc334
Compare
…dimensions from affine maps Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
git-crd
pushed a commit
to git-crd/crd-llvm-project
that referenced
this pull request
Nov 13, 2025
…Dimensions from Affine Maps (llvm#167587) This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR exposes
linalg::inferContractionDims(ArrayRef<AffineMap>)to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation.