From 7cda1b4a0de87393f0f3645743d99a6cd1ea047d Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Tue, 25 Nov 2025 18:13:49 -0600 Subject: [PATCH 1/6] [bindings] Expose scaled contraction matchers in python binding for tuner Signed-off-by: Muzammiluddin Syed --- .../c/iree/compiler/dialects/iree_codegen.h | 17 ++ .../python/IREECompilerDialectsModule.cpp | 48 ++++ .../python/test/api/tuner_api_test.py | 255 ++++++++++++++++++ .../API/Internal/IREECodegenDialectCAPI.cpp | 85 ++++++ 4 files changed, 405 insertions(+) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index c5e625501fc9..6e60f2c6c71d 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -119,6 +119,23 @@ MLIR_CAPI_EXPORTED bool ireeCodegenHasIGEMMGenericConvDetails(MlirOperation op); MLIR_CAPI_EXPORTED ireeCodegenIGEMMGenericConvDetails ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op); +typedef struct ireeCodegenScaledContractionDimensions { + MlirAttribute batch; + MlirAttribute m; + MlirAttribute n; + MlirAttribute k; + MlirAttribute kB; +} ireeCodegenScaledContractionDimensions; + +MLIR_CAPI_EXPORTED bool ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op); + +MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions +ireeCodegenInferScaledContractionDimensions(MlirOperation op); + +MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions +ireeCodegenInferScaledContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps); + #ifdef __cplusplus } #endif diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index beb0e1860e77..f6f27217aae1 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -770,4 +770,52 @@ NB_MODULE(_ireeCompilerDialects, m) { "Gets IGEMM details for a linalg operation. " "Returns None if failed to infer IGEMM convolution details.", py::arg("linalg_op")); + + //===-------------------------------------------------------------------===// + // Binding to utility function ireeCodegenGetScaledContractionDetails + //===-------------------------------------------------------------------===// + iree_codegen_module.def( + "isa_scaled_contraction_op", &ireeCodegenMlirOperationIsAScaledContractionOp, + "Checks if the given operation is an IREE LinalgExt scaled contraction op.", + py::arg("op")); + + //===-------------------------------------------------------------------===// + // Binding to struct ireeCodegenScaledContractionDimensions + //===-------------------------------------------------------------------===// + py::class_(iree_codegen_module, + "ScaledContractionDimensions") + .def_prop_ro("batch", + [](const ireeCodegenScaledContractionDimensions &self) { + return getIntArrayAttrValues(self.batch); + }) + .def_prop_ro("m", + [](const ireeCodegenScaledContractionDimensions &self) { + return getIntArrayAttrValues(self.m); + }) + .def_prop_ro("n", + [](const ireeCodegenScaledContractionDimensions &self) { + return getIntArrayAttrValues(self.n); + }) + .def_prop_ro("k", + [](const ireeCodegenScaledContractionDimensions &self) { + return getIntArrayAttrValues(self.k); + }) + .def_prop_ro("kB", + [](const ireeCodegenScaledContractionDimensions &self) { + return getIntArrayAttrValues(self.kB); + }); + + iree_codegen_module.def( + "infer_scaled_contraction_dimensions", &ireeCodegenInferScaledContractionDimensions, + "Infers the scaled contraction dimensions for a given operation.", + py::arg("op")); + + iree_codegen_module.def( + "infer_scaled_contraction_dimensions_from_maps", + [](const std::vector &indexingMaps) -> ireeCodegenScaledContractionDimensions { + return ireeCodegenInferScaledContractionDimensionsFromMaps( + indexingMaps.data(), indexingMaps.size()); + }, + "Infers the scaled contraction dimensions for a given operation from indexing maps.", + py::arg("indexing_maps")); } diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index e59813f90096..de1f3d4e2288 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -336,3 +336,258 @@ def test_igemm_conv_details(): details = iree_codegen.get_igemm_generic_conv_details(matmul_op) assert details is None, "IGEMM details should be None for non-conv operation" + + +@run +def test_isa_scaled_contraction_op(): + # Test 1: Regular matmul is not a scaled contraction + module_str = """ + module { + func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> + } + } + """ + input_module = ir.Module.parse(module_str) + assert input_module is not None, "Failed to parse input MLIR module" + root_op_list = iree_codegen.get_tuner_root_ops(input_module) + assert len(root_op_list) == 1 + matmul_op = root_op_list[0] + + # Regular matmul should not be a scaled contraction + assert not iree_codegen.isa_scaled_contraction_op(matmul_op), \ + "Regular matmul should not be a scaled contraction" + + # Test 2: Fill op is not a scaled contraction + module_str = """ + module { + func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill { root_op } ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> + } + } + """ + input_module = ir.Module.parse(module_str) + root_op_list = iree_codegen.get_tuner_root_ops(input_module) + assert len(root_op_list) == 1 + fill_op = root_op_list[0] + + assert not iree_codegen.isa_scaled_contraction_op(fill_op), \ + "Fill op should not be a scaled contraction" + + # Test 3: Scaled matmul as linalg.generic should be detected + # Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output) + # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb) + # Uses f4E2M1FN for operands and f8E8M0FNU for scales (matching real scaled matmul pattern) + module_str = """ + module { + func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>, + %lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>, + %out: tensor<16x16xf32>) -> tensor<16x16xf32> { + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction", "reduction"], + root_op + } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>) + outs(%out : tensor<16x16xf32>) { + ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32): + %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32 + %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32 + %prod = arith.mulf %a_scaled, %b_scaled : f32 + %sum = arith.addf %acc, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<16x16xf32> + return %result : tensor<16x16xf32> + } + } + """ + input_module = ir.Module.parse(module_str) + root_op_list = iree_codegen.get_tuner_root_ops(input_module) + assert len(root_op_list) == 1, "Should have one root op" + scaled_generic_op = root_op_list[0] + + # Check if it's recognized as a scaled contraction + is_scaled = iree_codegen.isa_scaled_contraction_op(scaled_generic_op) + assert is_scaled, "linalg.generic with scaled matmul pattern should be detected as scaled contraction" + + # Try to infer dimensions + dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op) + assert dims is not None, "Should be able to infer dimensions for scaled contraction" + + # Expected: m=[0], n=[1], k=[2], kB=[3] for the given indexing maps + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" + assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" + + +@run +def test_infer_scaled_contraction_dimensions(): + # Test 1: Verify dimension inference on a scaled matmul operation + module_str = """ + module { + func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>, + %lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>, + %out: tensor<16x16xf32>) -> tensor<16x16xf32> { + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction", "reduction"], + root_op + } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>) + outs(%out : tensor<16x16xf32>) { + ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32): + %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32 + %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32 + %prod = arith.mulf %a_scaled, %b_scaled : f32 + %sum = arith.addf %acc, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<16x16xf32> + return %result : tensor<16x16xf32> + } + } + """ + input_module = ir.Module.parse(module_str) + root_op_list = iree_codegen.get_tuner_root_ops(input_module) + assert len(root_op_list) == 1, "Should have exactly one root op" + scaled_op = root_op_list[0] + + # Verify it's a scaled contraction first + assert iree_codegen.isa_scaled_contraction_op(scaled_op), \ + "Operation should be recognized as scaled contraction" + + # Test dimension inference + dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op) + assert dims is not None, "Should successfully infer dimensions" + + # Verify the inferred dimensions match expected values + # For the given indexing maps: + # d0 = M (parallel) -> m + # d1 = N (parallel) -> n + # d2 = Ko (reduction) -> k + # d3 = Kb (reduction, block dim) -> kB + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" + assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" + + # Test 2: Non-scaled contraction should return None + module_str_regular = """ + module { + func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> + } + } + """ + input_module_regular = ir.Module.parse(module_str_regular) + regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular) + assert len(regular_ops) == 1 + regular_matmul = regular_ops[0] + + # Regular matmul should not have scaled contraction dimensions + dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul) + # Check if all dimensions are empty (indicating it's not a scaled contraction) + if dims_regular is not None: + all_empty = (len(list(dims_regular.m)) == 0 and + len(list(dims_regular.n)) == 0 and + len(list(dims_regular.k)) == 0 and + len(list(dims_regular.kB)) == 0 and + len(list(dims_regular.batch)) == 0) + assert all_empty or dims_regular is None, \ + "Regular matmul should not have valid scaled contraction dimensions" + + # Test 3: Batched scaled matmul + module_str_batched = """ + module { + func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>, + %lhs_scales: tensor<8x16x4xf8E8M0FNU>, %rhs_scales: tensor<8x16x4xf8E8M0FNU>, + %out: tensor<8x16x16xf32>) -> tensor<8x16x16xf32> { + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + ], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], + root_op + } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4xf8E8M0FNU>, tensor<8x16x4xf8E8M0FNU>) + outs(%out : tensor<8x16x16xf32>) { + ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32): + %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32 + %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32 + %prod = arith.mulf %a_scaled, %b_scaled : f32 + %sum = arith.addf %acc, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<8x16x16xf32> + return %result : tensor<8x16x16xf32> + } + } + """ + input_module_batched = ir.Module.parse(module_str_batched) + batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched) + if len(batched_ops) == 1: + batched_op = batched_ops[0] + assert iree_codegen.isa_scaled_contraction_op(batched_op), \ + "Batched scaled matmul should be recognized" + + dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) + if dims_batched is not None: + # Expected: batch=[0], m=[1], n=[2], k=[3], kB=[4] + assert list(dims_batched.batch) == [0], f"Expected batch=[0], got {list(dims_batched.batch)}" + assert list(dims_batched.m) == [1], f"Expected m=[1], got {list(dims_batched.m)}" + assert list(dims_batched.n) == [2], f"Expected n=[2], got {list(dims_batched.n)}" + assert list(dims_batched.k) == [3], f"Expected k=[3], got {list(dims_batched.k)}" + assert list(dims_batched.kB) == [4], f"Expected kB=[4], got {list(dims_batched.kB)}" + + +@run +def test_infer_scaled_contraction_dimensions_from_maps(): + # Test inferring scaled contraction dimensions from affine maps + # This follows the pattern of a scaled matmul with block scaling + # Pattern: (M, N, Ko, Kb) where Ko is the outer reduction and Kb is the block dimension + d0, d1, d2, d3 = [AffineDimExpr.get(i) for i in range(4)] + + # Maps for scaled contraction matching the example: + # lhs_map: (M, Ko, Kb) - left operand with outer and block reduction dims + # rhs_map: (N, Ko, Kb) - right operand with outer and block reduction dims + # lhs_scale_map: (M, Ko) - left scale factors indexed by M and Ko + # rhs_scale_map: (N, Ko) - right scale factors indexed by N and Ko + # out_map: (M, N) - output indexed by parallel dims only + + lhs_map = AffineMap.get(4, 0, [d0, d2, d3]) # (M, Ko, Kb) + rhs_map = AffineMap.get(4, 0, [d1, d2, d3]) # (N, Ko, Kb) + lhs_scale_map = AffineMap.get(4, 0, [d0, d2]) # (M, Ko) + rhs_scale_map = AffineMap.get(4, 0, [d1, d2]) # (N, Ko) + out_map = AffineMap.get(4, 0, [d0, d1]) # (M, N) + + # Call the inference function + dims = iree_codegen.infer_scaled_contraction_dimensions_from_maps( + [lhs_map, rhs_map, lhs_scale_map, rhs_scale_map, out_map] + ) + + assert dims is not None, "Should be able to infer scaled contraction dimensions from maps" + # Verify the inferred dimensions + # Expected: m=[0] (d0), n=[1] (d1), k=[2] (d2/Ko), kB=[3] (d3/Kb) + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" + assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index 619f40d55beb..df03dd0e4592 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" +#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h" #include "iree/compiler/dialects/iree_codegen.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" @@ -320,3 +321,87 @@ ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op) { return result; } + +bool ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op) { + auto linalgOp = llvm::cast(unwrap(op)); + return mlir::iree_compiler::IREE::LinalgExt::isaScaledContractionOpInterface( + linalgOp); +} + +ireeCodegenScaledContractionDimensions +ireeCodegenInferScaledContractionDimensions(MlirOperation op) { + ireeCodegenScaledContractionDimensions result{}; + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) { + return result; + } + + llvm::FailureOr< + mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions> + maybeDims = + mlir::iree_compiler::IREE::LinalgExt::inferScaledContractionDims( + linalgOp); + if (failed(maybeDims)) { + return result; + } + + const mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions + &scaledContractionDims = *maybeDims; + mlir::MLIRContext *ctx = linalgOp.getContext(); + + auto toAttr = [ctx](llvm::ArrayRef vals) -> MlirAttribute { + mlir::Builder b(ctx); + llvm::SmallVector attrs; + for (unsigned val : vals) { + attrs.push_back(b.getI32IntegerAttr(val)); + } + return wrap(b.getArrayAttr(attrs)); + }; + + result.batch = toAttr(scaledContractionDims.batch); + result.m = toAttr(scaledContractionDims.m); + result.n = toAttr(scaledContractionDims.n); + result.k = toAttr(scaledContractionDims.k); + result.kB = toAttr(scaledContractionDims.kB); + return result; +} + +ireeCodegenScaledContractionDimensions +ireeCodegenInferScaledContractionDimensionsFromMaps( + const MlirAffineMap *indexingMaps, size_t numMaps) { + ireeCodegenScaledContractionDimensions result{}; + if (!indexingMaps || numMaps == 0) { + return result; + } + + llvm::SmallVector maps; + for (size_t i = 0; i < numMaps; ++i) { + maps.push_back(unwrap(indexingMaps[i])); + } + + llvm::FailureOr< + mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions> + maybeDims = + mlir::iree_compiler::IREE::LinalgExt::inferScaledContractionDims( + maps); + if (failed(maybeDims)) { + return result; + } + + mlir::MLIRContext *ctx = maps[0].getContext(); + auto toAttr = [ctx](llvm::ArrayRef vals) -> MlirAttribute { + mlir::Builder b(ctx); + llvm::SmallVector attrs; + for (unsigned val : vals) { + attrs.push_back(b.getI32IntegerAttr(val)); + } + return wrap(b.getArrayAttr(attrs)); + }; + + result.batch = toAttr(maybeDims->batch); + result.m = toAttr(maybeDims->m); + result.n = toAttr(maybeDims->n); + result.k = toAttr(maybeDims->k); + result.kB = toAttr(maybeDims->kB); + return result; +} \ No newline at end of file From 60f3c7dd5aa52541905ade827e7808921ea3f64f Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Tue, 25 Nov 2025 19:22:58 -0600 Subject: [PATCH 2/6] fixing lint issues Signed-off-by: Muzammiluddin Syed --- .../c/iree/compiler/dialects/iree_codegen.h | 7 +- .../python/IREECompilerDialectsModule.cpp | 22 ++-- .../python/test/api/tuner_api_test.py | 113 +++++++++++------- .../API/Internal/IREECodegenDialectCAPI.cpp | 4 +- 4 files changed, 86 insertions(+), 60 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index 6e60f2c6c71d..be88bf84fee2 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -127,14 +127,15 @@ typedef struct ireeCodegenScaledContractionDimensions { MlirAttribute kB; } ireeCodegenScaledContractionDimensions; -MLIR_CAPI_EXPORTED bool ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op); +MLIR_CAPI_EXPORTED bool +ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op); MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions ireeCodegenInferScaledContractionDimensions(MlirOperation op); MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions -ireeCodegenInferScaledContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, - size_t numMaps); +ireeCodegenInferScaledContractionDimensionsFromMaps( + const MlirAffineMap *indexingMaps, size_t numMaps); #ifdef __cplusplus } diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index f6f27217aae1..fed29d947a61 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -774,16 +774,17 @@ NB_MODULE(_ireeCompilerDialects, m) { //===-------------------------------------------------------------------===// // Binding to utility function ireeCodegenGetScaledContractionDetails //===-------------------------------------------------------------------===// - iree_codegen_module.def( - "isa_scaled_contraction_op", &ireeCodegenMlirOperationIsAScaledContractionOp, - "Checks if the given operation is an IREE LinalgExt scaled contraction op.", - py::arg("op")); + iree_codegen_module.def("isa_scaled_contraction_op", + &ireeCodegenMlirOperationIsAScaledContractionOp, + "Checks if the given operation is an IREE LinalgExt " + "scaled contraction op.", + py::arg("op")); //===-------------------------------------------------------------------===// // Binding to struct ireeCodegenScaledContractionDimensions //===-------------------------------------------------------------------===// - py::class_(iree_codegen_module, - "ScaledContractionDimensions") + py::class_( + iree_codegen_module, "ScaledContractionDimensions") .def_prop_ro("batch", [](const ireeCodegenScaledContractionDimensions &self) { return getIntArrayAttrValues(self.batch); @@ -806,16 +807,19 @@ NB_MODULE(_ireeCompilerDialects, m) { }); iree_codegen_module.def( - "infer_scaled_contraction_dimensions", &ireeCodegenInferScaledContractionDimensions, + "infer_scaled_contraction_dimensions", + &ireeCodegenInferScaledContractionDimensions, "Infers the scaled contraction dimensions for a given operation.", py::arg("op")); iree_codegen_module.def( "infer_scaled_contraction_dimensions_from_maps", - [](const std::vector &indexingMaps) -> ireeCodegenScaledContractionDimensions { + [](const std::vector &indexingMaps) + -> ireeCodegenScaledContractionDimensions { return ireeCodegenInferScaledContractionDimensionsFromMaps( indexingMaps.data(), indexingMaps.size()); }, - "Infers the scaled contraction dimensions for a given operation from indexing maps.", + "Infers the scaled contraction dimensions for a given operation from " + "indexing maps.", py::arg("indexing_maps")); } diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index de1f3d4e2288..f7c4083c77a3 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -354,11 +354,12 @@ def test_isa_scaled_contraction_op(): root_op_list = iree_codegen.get_tuner_root_ops(input_module) assert len(root_op_list) == 1 matmul_op = root_op_list[0] - + # Regular matmul should not be a scaled contraction - assert not iree_codegen.isa_scaled_contraction_op(matmul_op), \ - "Regular matmul should not be a scaled contraction" - + assert not iree_codegen.isa_scaled_contraction_op( + matmul_op + ), "Regular matmul should not be a scaled contraction" + # Test 2: Fill op is not a scaled contraction module_str = """ module { @@ -373,10 +374,11 @@ def test_isa_scaled_contraction_op(): root_op_list = iree_codegen.get_tuner_root_ops(input_module) assert len(root_op_list) == 1 fill_op = root_op_list[0] - - assert not iree_codegen.isa_scaled_contraction_op(fill_op), \ - "Fill op should not be a scaled contraction" - + + assert not iree_codegen.isa_scaled_contraction_op( + fill_op + ), "Fill op should not be a scaled contraction" + # Test 3: Scaled matmul as linalg.generic should be detected # Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output) # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb) @@ -413,15 +415,17 @@ def test_isa_scaled_contraction_op(): root_op_list = iree_codegen.get_tuner_root_ops(input_module) assert len(root_op_list) == 1, "Should have one root op" scaled_generic_op = root_op_list[0] - + # Check if it's recognized as a scaled contraction is_scaled = iree_codegen.isa_scaled_contraction_op(scaled_generic_op) - assert is_scaled, "linalg.generic with scaled matmul pattern should be detected as scaled contraction" - + assert ( + is_scaled + ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction" + # Try to infer dimensions dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op) assert dims is not None, "Should be able to infer dimensions for scaled contraction" - + # Expected: m=[0], n=[1], k=[2], kB=[3] for the given indexing maps assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" @@ -465,19 +469,20 @@ def test_infer_scaled_contraction_dimensions(): root_op_list = iree_codegen.get_tuner_root_ops(input_module) assert len(root_op_list) == 1, "Should have exactly one root op" scaled_op = root_op_list[0] - + # Verify it's a scaled contraction first - assert iree_codegen.isa_scaled_contraction_op(scaled_op), \ - "Operation should be recognized as scaled contraction" - + assert iree_codegen.isa_scaled_contraction_op( + scaled_op + ), "Operation should be recognized as scaled contraction" + # Test dimension inference dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op) assert dims is not None, "Should successfully infer dimensions" - + # Verify the inferred dimensions match expected values # For the given indexing maps: # d0 = M (parallel) -> m - # d1 = N (parallel) -> n + # d1 = N (parallel) -> n # d2 = Ko (reduction) -> k # d3 = Kb (reduction, block dim) -> kB assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" @@ -485,7 +490,7 @@ def test_infer_scaled_contraction_dimensions(): assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" - + # Test 2: Non-scaled contraction should return None module_str_regular = """ module { @@ -499,19 +504,22 @@ def test_infer_scaled_contraction_dimensions(): regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular) assert len(regular_ops) == 1 regular_matmul = regular_ops[0] - + # Regular matmul should not have scaled contraction dimensions dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul) # Check if all dimensions are empty (indicating it's not a scaled contraction) if dims_regular is not None: - all_empty = (len(list(dims_regular.m)) == 0 and - len(list(dims_regular.n)) == 0 and - len(list(dims_regular.k)) == 0 and - len(list(dims_regular.kB)) == 0 and - len(list(dims_regular.batch)) == 0) - assert all_empty or dims_regular is None, \ - "Regular matmul should not have valid scaled contraction dimensions" - + all_empty = ( + len(list(dims_regular.m)) == 0 + and len(list(dims_regular.n)) == 0 + and len(list(dims_regular.k)) == 0 + and len(list(dims_regular.kB)) == 0 + and len(list(dims_regular.batch)) == 0 + ) + assert ( + all_empty or dims_regular is None + ), "Regular matmul should not have valid scaled contraction dimensions" + # Test 3: Batched scaled matmul module_str_batched = """ module { @@ -545,17 +553,28 @@ def test_infer_scaled_contraction_dimensions(): batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched) if len(batched_ops) == 1: batched_op = batched_ops[0] - assert iree_codegen.isa_scaled_contraction_op(batched_op), \ - "Batched scaled matmul should be recognized" - + assert iree_codegen.isa_scaled_contraction_op( + batched_op + ), "Batched scaled matmul should be recognized" + dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) if dims_batched is not None: # Expected: batch=[0], m=[1], n=[2], k=[3], kB=[4] - assert list(dims_batched.batch) == [0], f"Expected batch=[0], got {list(dims_batched.batch)}" - assert list(dims_batched.m) == [1], f"Expected m=[1], got {list(dims_batched.m)}" - assert list(dims_batched.n) == [2], f"Expected n=[2], got {list(dims_batched.n)}" - assert list(dims_batched.k) == [3], f"Expected k=[3], got {list(dims_batched.k)}" - assert list(dims_batched.kB) == [4], f"Expected kB=[4], got {list(dims_batched.kB)}" + assert list(dims_batched.batch) == [ + 0 + ], f"Expected batch=[0], got {list(dims_batched.batch)}" + assert list(dims_batched.m) == [ + 1 + ], f"Expected m=[1], got {list(dims_batched.m)}" + assert list(dims_batched.n) == [ + 2 + ], f"Expected n=[2], got {list(dims_batched.n)}" + assert list(dims_batched.k) == [ + 3 + ], f"Expected k=[3], got {list(dims_batched.k)}" + assert list(dims_batched.kB) == [ + 4 + ], f"Expected kB=[4], got {list(dims_batched.kB)}" @run @@ -564,26 +583,28 @@ def test_infer_scaled_contraction_dimensions_from_maps(): # This follows the pattern of a scaled matmul with block scaling # Pattern: (M, N, Ko, Kb) where Ko is the outer reduction and Kb is the block dimension d0, d1, d2, d3 = [AffineDimExpr.get(i) for i in range(4)] - + # Maps for scaled contraction matching the example: # lhs_map: (M, Ko, Kb) - left operand with outer and block reduction dims # rhs_map: (N, Ko, Kb) - right operand with outer and block reduction dims # lhs_scale_map: (M, Ko) - left scale factors indexed by M and Ko # rhs_scale_map: (N, Ko) - right scale factors indexed by N and Ko # out_map: (M, N) - output indexed by parallel dims only - - lhs_map = AffineMap.get(4, 0, [d0, d2, d3]) # (M, Ko, Kb) - rhs_map = AffineMap.get(4, 0, [d1, d2, d3]) # (N, Ko, Kb) - lhs_scale_map = AffineMap.get(4, 0, [d0, d2]) # (M, Ko) - rhs_scale_map = AffineMap.get(4, 0, [d1, d2]) # (N, Ko) - out_map = AffineMap.get(4, 0, [d0, d1]) # (M, N) - + + lhs_map = AffineMap.get(4, 0, [d0, d2, d3]) # (M, Ko, Kb) + rhs_map = AffineMap.get(4, 0, [d1, d2, d3]) # (N, Ko, Kb) + lhs_scale_map = AffineMap.get(4, 0, [d0, d2]) # (M, Ko) + rhs_scale_map = AffineMap.get(4, 0, [d1, d2]) # (N, Ko) + out_map = AffineMap.get(4, 0, [d0, d1]) # (M, N) + # Call the inference function dims = iree_codegen.infer_scaled_contraction_dimensions_from_maps( [lhs_map, rhs_map, lhs_scale_map, rhs_scale_map, out_map] ) - - assert dims is not None, "Should be able to infer scaled contraction dimensions from maps" + + assert ( + dims is not None + ), "Should be able to infer scaled contraction dimensions from maps" # Verify the inferred dimensions # Expected: m=[0] (d0), n=[1] (d1), k=[2] (d2/Ko), kB=[3] (d3/Kb) assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index df03dd0e4592..4519683594ce 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -13,8 +13,8 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" -#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h" +#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/dialects/iree_codegen.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" @@ -404,4 +404,4 @@ ireeCodegenInferScaledContractionDimensionsFromMaps( result.k = toAttr(maybeDims->k); result.kB = toAttr(maybeDims->kB); return result; -} \ No newline at end of file +} From ad026e05369999e276161fd9b36c9f7eba0c6442 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 26 Nov 2025 09:12:42 -0600 Subject: [PATCH 3/6] Address PR comments Signed-off-by: Muzammiluddin Syed --- .../c/iree/compiler/dialects/iree_codegen.h | 8 +- .../python/IREECompilerDialectsModule.cpp | 11 --- .../python/test/api/tuner_api_test.py | 85 ++++++------------- .../API/Internal/IREECodegenDialectCAPI.cpp | 50 ++--------- compiler/src/iree/compiler/API/api_exports.c | 4 + .../src/iree/compiler/API/api_exports.def | 2 + compiler/src/iree/compiler/API/api_exports.ld | 2 + .../iree/compiler/API/api_exports.macos.lst | 2 + 8 files changed, 41 insertions(+), 123 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index be88bf84fee2..f6e4219fae02 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -119,13 +119,13 @@ MLIR_CAPI_EXPORTED bool ireeCodegenHasIGEMMGenericConvDetails(MlirOperation op); MLIR_CAPI_EXPORTED ireeCodegenIGEMMGenericConvDetails ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op); -typedef struct ireeCodegenScaledContractionDimensions { +struct ireeCodegenScaledContractionDimensions { MlirAttribute batch; MlirAttribute m; MlirAttribute n; MlirAttribute k; MlirAttribute kB; -} ireeCodegenScaledContractionDimensions; +}; MLIR_CAPI_EXPORTED bool ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op); @@ -133,10 +133,6 @@ ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op); MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions ireeCodegenInferScaledContractionDimensions(MlirOperation op); -MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions -ireeCodegenInferScaledContractionDimensionsFromMaps( - const MlirAffineMap *indexingMaps, size_t numMaps); - #ifdef __cplusplus } #endif diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index fed29d947a61..3e69fbfc8018 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -811,15 +811,4 @@ NB_MODULE(_ireeCompilerDialects, m) { &ireeCodegenInferScaledContractionDimensions, "Infers the scaled contraction dimensions for a given operation.", py::arg("op")); - - iree_codegen_module.def( - "infer_scaled_contraction_dimensions_from_maps", - [](const std::vector &indexingMaps) - -> ireeCodegenScaledContractionDimensions { - return ireeCodegenInferScaledContractionDimensionsFromMaps( - indexingMaps.data(), indexingMaps.size()); - }, - "Infers the scaled contraction dimensions for a given operation from " - "indexing maps.", - py::arg("indexing_maps")); } diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index f7c4083c77a3..9c521b393ad0 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -85,7 +85,7 @@ def attention_op_detail(): ) # (d0, d3, d2). v_map = affine.AffineMap.get( 5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]] - ) # (d0, d3, d4). # () + ) # (d0, d3, d4). o_map = affine.AffineMap.get( 5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]] ) # (d0, d1, d4). @@ -340,7 +340,7 @@ def test_igemm_conv_details(): @run def test_isa_scaled_contraction_op(): - # Test 1: Regular matmul is not a scaled contraction + # Test 1: Regular matmul is not a scaled contraction. module_str = """ module { func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { @@ -355,12 +355,12 @@ def test_isa_scaled_contraction_op(): assert len(root_op_list) == 1 matmul_op = root_op_list[0] - # Regular matmul should not be a scaled contraction + # Regular matmul should not be a scaled contraction. assert not iree_codegen.isa_scaled_contraction_op( matmul_op ), "Regular matmul should not be a scaled contraction" - # Test 2: Fill op is not a scaled contraction + # Test 2: Fill op is not a scaled contraction. module_str = """ module { func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { @@ -379,10 +379,10 @@ def test_isa_scaled_contraction_op(): fill_op ), "Fill op should not be a scaled contraction" - # Test 3: Scaled matmul as linalg.generic should be detected - # Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output) - # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb) - # Uses f4E2M1FN for operands and f8E8M0FNU for scales (matching real scaled matmul pattern) + # Test 3: Scaled matmul as linalg.generic should be detected. + # Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output), + # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb). + # Uses f4E2M1FN for operands and f8E8M0FNU for scales (matching real scaled matmul pattern). module_str = """ module { func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>, @@ -416,17 +416,17 @@ def test_isa_scaled_contraction_op(): assert len(root_op_list) == 1, "Should have one root op" scaled_generic_op = root_op_list[0] - # Check if it's recognized as a scaled contraction + # Check if it's recognized as a scaled contraction. is_scaled = iree_codegen.isa_scaled_contraction_op(scaled_generic_op) assert ( is_scaled ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction" - # Try to infer dimensions + # Try to infer dimensions. dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op) assert dims is not None, "Should be able to infer dimensions for scaled contraction" - # Expected: m=[0], n=[1], k=[2], kB=[3] for the given indexing maps + # Expected: m=[0], n=[1], k=[2], kB=[3] for the given indexing maps. assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" @@ -436,7 +436,7 @@ def test_isa_scaled_contraction_op(): @run def test_infer_scaled_contraction_dimensions(): - # Test 1: Verify dimension inference on a scaled matmul operation + # Test 1: Verify dimension inference on a scaled matmul operation. module_str = """ module { func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>, @@ -470,28 +470,28 @@ def test_infer_scaled_contraction_dimensions(): assert len(root_op_list) == 1, "Should have exactly one root op" scaled_op = root_op_list[0] - # Verify it's a scaled contraction first + # Verify it's a scaled contraction first. assert iree_codegen.isa_scaled_contraction_op( scaled_op ), "Operation should be recognized as scaled contraction" - # Test dimension inference + # Test dimension inference. dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op) assert dims is not None, "Should successfully infer dimensions" - # Verify the inferred dimensions match expected values + # Verify the inferred dimensions match expected values. # For the given indexing maps: - # d0 = M (parallel) -> m - # d1 = N (parallel) -> n - # d2 = Ko (reduction) -> k - # d3 = Kb (reduction, block dim) -> kB + # d0 = M (parallel) -> m, + # d1 = N (parallel) -> n, + # d2 = Ko (reduction) -> k, + # d3 = Kb (reduction, block dim) -> kB. assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" - # Test 2: Non-scaled contraction should return None + # Test 2: Non-scaled contraction should return None. module_str_regular = """ module { func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { @@ -505,9 +505,9 @@ def test_infer_scaled_contraction_dimensions(): assert len(regular_ops) == 1 regular_matmul = regular_ops[0] - # Regular matmul should not have scaled contraction dimensions + # Regular matmul should not have scaled contraction dimensions. dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul) - # Check if all dimensions are empty (indicating it's not a scaled contraction) + # Check if all dimensions are empty (indicating it's not a scaled contraction). if dims_regular is not None: all_empty = ( len(list(dims_regular.m)) == 0 @@ -520,7 +520,7 @@ def test_infer_scaled_contraction_dimensions(): all_empty or dims_regular is None ), "Regular matmul should not have valid scaled contraction dimensions" - # Test 3: Batched scaled matmul + # Test 3: Batched scaled matmul. module_str_batched = """ module { func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>, @@ -559,7 +559,7 @@ def test_infer_scaled_contraction_dimensions(): dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) if dims_batched is not None: - # Expected: batch=[0], m=[1], n=[2], k=[3], kB=[4] + # Expected: batch=[0], m=[1], n=[2], k=[3], kB=[4]. assert list(dims_batched.batch) == [ 0 ], f"Expected batch=[0], got {list(dims_batched.batch)}" @@ -575,40 +575,3 @@ def test_infer_scaled_contraction_dimensions(): assert list(dims_batched.kB) == [ 4 ], f"Expected kB=[4], got {list(dims_batched.kB)}" - - -@run -def test_infer_scaled_contraction_dimensions_from_maps(): - # Test inferring scaled contraction dimensions from affine maps - # This follows the pattern of a scaled matmul with block scaling - # Pattern: (M, N, Ko, Kb) where Ko is the outer reduction and Kb is the block dimension - d0, d1, d2, d3 = [AffineDimExpr.get(i) for i in range(4)] - - # Maps for scaled contraction matching the example: - # lhs_map: (M, Ko, Kb) - left operand with outer and block reduction dims - # rhs_map: (N, Ko, Kb) - right operand with outer and block reduction dims - # lhs_scale_map: (M, Ko) - left scale factors indexed by M and Ko - # rhs_scale_map: (N, Ko) - right scale factors indexed by N and Ko - # out_map: (M, N) - output indexed by parallel dims only - - lhs_map = AffineMap.get(4, 0, [d0, d2, d3]) # (M, Ko, Kb) - rhs_map = AffineMap.get(4, 0, [d1, d2, d3]) # (N, Ko, Kb) - lhs_scale_map = AffineMap.get(4, 0, [d0, d2]) # (M, Ko) - rhs_scale_map = AffineMap.get(4, 0, [d1, d2]) # (N, Ko) - out_map = AffineMap.get(4, 0, [d0, d1]) # (M, N) - - # Call the inference function - dims = iree_codegen.infer_scaled_contraction_dimensions_from_maps( - [lhs_map, rhs_map, lhs_scale_map, rhs_scale_map, out_map] - ) - - assert ( - dims is not None - ), "Should be able to infer scaled contraction dimensions from maps" - # Verify the inferred dimensions - # Expected: m=[0] (d0), n=[1] (d1), k=[2] (d2/Ko), kB=[3] (d3/Kb) - assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" - assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" - assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" - assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" - assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index 4519683594ce..c2382ad3dbbf 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -349,12 +349,12 @@ ireeCodegenInferScaledContractionDimensions(MlirOperation op) { &scaledContractionDims = *maybeDims; mlir::MLIRContext *ctx = linalgOp.getContext(); - auto toAttr = [ctx](llvm::ArrayRef vals) -> MlirAttribute { + auto toAttr = [&ctx](llvm::ArrayRef vals) -> MlirAttribute { mlir::Builder b(ctx); - llvm::SmallVector attrs; - for (unsigned val : vals) { - attrs.push_back(b.getI32IntegerAttr(val)); - } + llvm::SmallVector attrs = + llvm::map_to_vector(vals, [&b](unsigned val) -> mlir::Attribute { + return b.getI32IntegerAttr(val); + }); return wrap(b.getArrayAttr(attrs)); }; @@ -365,43 +365,3 @@ ireeCodegenInferScaledContractionDimensions(MlirOperation op) { result.kB = toAttr(scaledContractionDims.kB); return result; } - -ireeCodegenScaledContractionDimensions -ireeCodegenInferScaledContractionDimensionsFromMaps( - const MlirAffineMap *indexingMaps, size_t numMaps) { - ireeCodegenScaledContractionDimensions result{}; - if (!indexingMaps || numMaps == 0) { - return result; - } - - llvm::SmallVector maps; - for (size_t i = 0; i < numMaps; ++i) { - maps.push_back(unwrap(indexingMaps[i])); - } - - llvm::FailureOr< - mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions> - maybeDims = - mlir::iree_compiler::IREE::LinalgExt::inferScaledContractionDims( - maps); - if (failed(maybeDims)) { - return result; - } - - mlir::MLIRContext *ctx = maps[0].getContext(); - auto toAttr = [ctx](llvm::ArrayRef vals) -> MlirAttribute { - mlir::Builder b(ctx); - llvm::SmallVector attrs; - for (unsigned val : vals) { - attrs.push_back(b.getI32IntegerAttr(val)); - } - return wrap(b.getArrayAttr(attrs)); - }; - - result.batch = toAttr(maybeDims->batch); - result.m = toAttr(maybeDims->m); - result.n = toAttr(maybeDims->n); - result.k = toAttr(maybeDims->k); - result.kB = toAttr(maybeDims->kB); - return result; -} diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 10f86ca1029d..edd1225342ef 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -30,7 +30,9 @@ extern void ireeCodegenGetExecutableVariantOps(); extern void ireeGPUGetSingleSubgroupLayout(); extern void ireeCodegenGetTunerRootOps(); extern void ireeCodegenGetAttentionOpDetail(); +extern void ireeCodegenInferScaledContractionDimensions(); extern void ireeCodegenMlirOperationIsACodegenAttentionOp(); +extern void ireeCodegenMlirOperationIsAScaledContractionOp(); extern void ireeCodegenHasIGEMMGenericConvDetails(); extern void ireeCodegenGetIGEMMGenericConvDetails(); extern void ireeCodegenTranslationInfoAttrGet(); @@ -949,7 +951,9 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeGPUGetSingleSubgroupLayout; x += (uintptr_t)&ireeCodegenGetTunerRootOps; x += (uintptr_t)&ireeCodegenGetAttentionOpDetail; + x += (uintptr_t)&ireeCodegenInferScaledContractionDimensions; x += (uintptr_t)&ireeCodegenMlirOperationIsACodegenAttentionOp; + x += (uintptr_t)&ireeCodegenMlirOperationIsAScaledContractionOp; x += (uintptr_t)&ireeCodegenHasIGEMMGenericConvDetails; x += (uintptr_t)&ireeCodegenGetIGEMMGenericConvDetails; x += (uintptr_t)&ireeCodegenTranslationInfoAttrGet; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index 110fbbf7a6ea..4472e8ad8776 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -20,7 +20,9 @@ EXPORTS ireeGPUGetSingleSubgroupLayout ireeCodegenGetTunerRootOps ireeCodegenGetAttentionOpDetail + ireeCodegenInferScaledContractionDimensions ireeCodegenMlirOperationIsACodegenAttentionOp + ireeCodegenMlirOperationIsAScaledContractionOp ireeCodegenHasIGEMMGenericConvDetails ireeCodegenGetIGEMMGenericConvDetails ireeCodegenTranslationInfoAttrGet diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 92ff2ba39882..2bb31329aa75 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -21,7 +21,9 @@ VER_0 { ireeGPUGetSingleSubgroupLayout; ireeCodegenGetTunerRootOps; ireeCodegenGetAttentionOpDetail; + ireeCodegenInferScaledContractionDimensions; ireeCodegenMlirOperationIsACodegenAttentionOp; + ireeCodegenMlirOperationIsAScaledContractionOp; ireeCodegenHasIGEMMGenericConvDetails; ireeCodegenGetIGEMMGenericConvDetails; ireeCodegenTranslationInfoAttrGet; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index bf8ad1a17223..564b434a7aba 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -19,7 +19,9 @@ _ireeCodegenGetExecutableVariantOps _ireeGPUGetSingleSubgroupLayout _ireeCodegenGetTunerRootOps _ireeCodegenGetAttentionOpDetail +_ireeCodegenInferScaledContractionDimensions _ireeCodegenMlirOperationIsACodegenAttentionOp +_ireeCodegenMlirOperationIsAScaledContractionOp _ireeCodegenHasIGEMMGenericConvDetails _ireeCodegenGetIGEMMGenericConvDetails _ireeCodegenTranslationInfoAttrGet From 0a4039941f01d82d458cd842af127d252a960712 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 26 Nov 2025 13:04:59 -0600 Subject: [PATCH 4/6] Adressing PR review 2 Signed-off-by: Muzammiluddin Syed --- .../c/iree/compiler/dialects/iree_codegen.h | 5 ++ .../python/test/api/tuner_api_test.py | 52 +++++++------------ .../API/Internal/IREECodegenDialectCAPI.cpp | 5 +- 3 files changed, 27 insertions(+), 35 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index f6e4219fae02..4dad761b06c8 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -120,10 +120,15 @@ MLIR_CAPI_EXPORTED ireeCodegenIGEMMGenericConvDetails ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op); struct ireeCodegenScaledContractionDimensions { + // Batch dimension for scaled contraction (ArrayAttr) MlirAttribute batch; + // M dimension for scaled contraction (ArrayAttr) MlirAttribute m; + // N dimension for scaled contraction (ArrayAttr) MlirAttribute n; + // K outer reduction dimension for scaled contraction (ArrayAttr) MlirAttribute k; + // K blocking dimension for scaled contraction (ArrayAttr) MlirAttribute kB; }; diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index 9c521b393ad0..f971f1625f9e 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -426,12 +426,11 @@ def test_isa_scaled_contraction_op(): dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op) assert dims is not None, "Should be able to infer dimensions for scaled contraction" - # Expected: m=[0], n=[1], k=[2], kB=[3] for the given indexing maps. - assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" - assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" - assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" - assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" - assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" + assert dims.m == [0], f"Got {dims.m}" + assert dims.n == [1], f"Got {dims.n}" + assert dims.k == [2], f"Got {dims.k}" + assert dims.kB == [3], f"Got {dims.kB}" + assert dims.batch == [], f"Got {dims.batch}" @run @@ -485,11 +484,11 @@ def test_infer_scaled_contraction_dimensions(): # d1 = N (parallel) -> n, # d2 = Ko (reduction) -> k, # d3 = Kb (reduction, block dim) -> kB. - assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" - assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" - assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" - assert list(dims.kB) == [3], f"Expected kB=[3], got {list(dims.kB)}" - assert list(dims.batch) == [], f"Expected no batch dims, got {list(dims.batch)}" + assert dims.m == [0], f"Got {dims.m}" + assert dims.n == [1], f"Got {dims.n}" + assert dims.k == [2], f"Got {dims.k}" + assert dims.kB == [3], f"Got {dims.kB}" + assert dims.batch == [], f"Got {dims.batch}" # Test 2: Non-scaled contraction should return None. module_str_regular = """ @@ -510,11 +509,11 @@ def test_infer_scaled_contraction_dimensions(): # Check if all dimensions are empty (indicating it's not a scaled contraction). if dims_regular is not None: all_empty = ( - len(list(dims_regular.m)) == 0 - and len(list(dims_regular.n)) == 0 - and len(list(dims_regular.k)) == 0 - and len(list(dims_regular.kB)) == 0 - and len(list(dims_regular.batch)) == 0 + len(dims_regular.m) == 0 + and len(dims_regular.n) == 0 + and len(dims_regular.k) == 0 + and len(dims_regular.kB) == 0 + and len(dims_regular.batch) == 0 ) assert ( all_empty or dims_regular is None @@ -559,19 +558,8 @@ def test_infer_scaled_contraction_dimensions(): dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) if dims_batched is not None: - # Expected: batch=[0], m=[1], n=[2], k=[3], kB=[4]. - assert list(dims_batched.batch) == [ - 0 - ], f"Expected batch=[0], got {list(dims_batched.batch)}" - assert list(dims_batched.m) == [ - 1 - ], f"Expected m=[1], got {list(dims_batched.m)}" - assert list(dims_batched.n) == [ - 2 - ], f"Expected n=[2], got {list(dims_batched.n)}" - assert list(dims_batched.k) == [ - 3 - ], f"Expected k=[3], got {list(dims_batched.k)}" - assert list(dims_batched.kB) == [ - 4 - ], f"Expected kB=[4], got {list(dims_batched.kB)}" + assert dims_batched.batch == [0], f"Got {dims_batched.batch}" + assert dims_batched.m == [1], f"Got {dims_batched.m}" + assert dims_batched.n == [2], f"Got {dims_batched.n}" + assert dims_batched.k == [3], f"Got {dims_batched.k}" + assert dims_batched.kB == [4], f"Got {dims_batched.kB}" diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index c2382ad3dbbf..271fb333e732 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -348,9 +348,8 @@ ireeCodegenInferScaledContractionDimensions(MlirOperation op) { const mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions &scaledContractionDims = *maybeDims; mlir::MLIRContext *ctx = linalgOp.getContext(); - - auto toAttr = [&ctx](llvm::ArrayRef vals) -> MlirAttribute { - mlir::Builder b(ctx); + mlir::Builder b(ctx); + auto toAttr = [&b](llvm::ArrayRef vals) -> MlirAttribute { llvm::SmallVector attrs = llvm::map_to_vector(vals, [&b](unsigned val) -> mlir::Attribute { return b.getI32IntegerAttr(val); From d09b9721bfb1449651dfa07b8bbc9f5db459794f Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 26 Nov 2025 13:58:46 -0600 Subject: [PATCH 5/6] Adress pr comments pt 3. Signed-off-by: Muzammiluddin Syed --- .../c/iree/compiler/dialects/iree_codegen.h | 10 +++--- .../python/test/api/tuner_api_test.py | 32 ++++++------------- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index 4dad761b06c8..abad5ba749ca 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -120,15 +120,15 @@ MLIR_CAPI_EXPORTED ireeCodegenIGEMMGenericConvDetails ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op); struct ireeCodegenScaledContractionDimensions { - // Batch dimension for scaled contraction (ArrayAttr) + // Batch dimension for scaled contraction (ArrayAttr). MlirAttribute batch; - // M dimension for scaled contraction (ArrayAttr) + // M dimension for scaled contraction (ArrayAttr). MlirAttribute m; - // N dimension for scaled contraction (ArrayAttr) + // N dimension for scaled contraction (ArrayAttr). MlirAttribute n; - // K outer reduction dimension for scaled contraction (ArrayAttr) + // K outer reduction dimension for scaled contraction (ArrayAttr). MlirAttribute k; - // K blocking dimension for scaled contraction (ArrayAttr) + // K blocking dimension for scaled contraction (ArrayAttr). MlirAttribute kB; }; diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index f971f1625f9e..c4c4a0be5080 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -77,18 +77,10 @@ def root_op(): def attention_op_detail(): dim_exprs = [affine.AffineDimExpr.get(i) for i in range(5)] - q_map = affine.AffineMap.get( - 5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[2]] - ) # (d0, d1, d2). - k_map = affine.AffineMap.get( - 5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[2]] - ) # (d0, d3, d2). - v_map = affine.AffineMap.get( - 5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]] - ) # (d0, d3, d4). - o_map = affine.AffineMap.get( - 5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]] - ) # (d0, d1, d4). + q_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[2]]) + k_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[2]]) + v_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]]) + o_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]]) result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map) @@ -102,10 +94,10 @@ def attention_op_detail(): dim_exprs = [affine.AffineDimExpr.get(i) for i in range(4)] # Input affine maps that do not follow the expected pattern for an attention operation. - q_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) # (d0, d1). - k_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[2]]) # (d0, d2). - v_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[3]]) # (d0, d3). - o_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) # (d0, d1). + q_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) + k_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[2]]) + v_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[3]]) + o_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map) assert result.domain_rank == 4 @@ -355,7 +347,6 @@ def test_isa_scaled_contraction_op(): assert len(root_op_list) == 1 matmul_op = root_op_list[0] - # Regular matmul should not be a scaled contraction. assert not iree_codegen.isa_scaled_contraction_op( matmul_op ), "Regular matmul should not be a scaled contraction" @@ -414,15 +405,13 @@ def test_isa_scaled_contraction_op(): input_module = ir.Module.parse(module_str) root_op_list = iree_codegen.get_tuner_root_ops(input_module) assert len(root_op_list) == 1, "Should have one root op" - scaled_generic_op = root_op_list[0] - # Check if it's recognized as a scaled contraction. + scaled_generic_op = root_op_list[0] is_scaled = iree_codegen.isa_scaled_contraction_op(scaled_generic_op) assert ( is_scaled ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction" - # Try to infer dimensions. dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op) assert dims is not None, "Should be able to infer dimensions for scaled contraction" @@ -474,7 +463,6 @@ def test_infer_scaled_contraction_dimensions(): scaled_op ), "Operation should be recognized as scaled contraction" - # Test dimension inference. dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op) assert dims is not None, "Should successfully infer dimensions" @@ -505,8 +493,8 @@ def test_infer_scaled_contraction_dimensions(): regular_matmul = regular_ops[0] # Regular matmul should not have scaled contraction dimensions. - dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul) # Check if all dimensions are empty (indicating it's not a scaled contraction). + dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul) if dims_regular is not None: all_empty = ( len(dims_regular.m) == 0 From c14199c943f19860adc7dd532b746901b8378305 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 26 Nov 2025 18:40:00 -0600 Subject: [PATCH 6/6] address pr comments pt. 4 Signed-off-by: Muzammiluddin Syed --- .../python/test/api/tuner_api_test.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/compiler/bindings/python/test/api/tuner_api_test.py b/compiler/bindings/python/test/api/tuner_api_test.py index c4c4a0be5080..7387a4ea0a6c 100644 --- a/compiler/bindings/python/test/api/tuner_api_test.py +++ b/compiler/bindings/python/test/api/tuner_api_test.py @@ -458,20 +458,12 @@ def test_infer_scaled_contraction_dimensions(): assert len(root_op_list) == 1, "Should have exactly one root op" scaled_op = root_op_list[0] - # Verify it's a scaled contraction first. assert iree_codegen.isa_scaled_contraction_op( scaled_op ), "Operation should be recognized as scaled contraction" dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op) assert dims is not None, "Should successfully infer dimensions" - - # Verify the inferred dimensions match expected values. - # For the given indexing maps: - # d0 = M (parallel) -> m, - # d1 = N (parallel) -> n, - # d2 = Ko (reduction) -> k, - # d3 = Kb (reduction, block dim) -> kB. assert dims.m == [0], f"Got {dims.m}" assert dims.n == [1], f"Got {dims.n}" assert dims.k == [2], f"Got {dims.k}" @@ -538,16 +530,18 @@ def test_infer_scaled_contraction_dimensions(): """ input_module_batched = ir.Module.parse(module_str_batched) batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched) - if len(batched_ops) == 1: - batched_op = batched_ops[0] - assert iree_codegen.isa_scaled_contraction_op( - batched_op - ), "Batched scaled matmul should be recognized" - - dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) - if dims_batched is not None: - assert dims_batched.batch == [0], f"Got {dims_batched.batch}" - assert dims_batched.m == [1], f"Got {dims_batched.m}" - assert dims_batched.n == [2], f"Got {dims_batched.n}" - assert dims_batched.k == [3], f"Got {dims_batched.k}" - assert dims_batched.kB == [4], f"Got {dims_batched.kB}" + assert len(batched_ops) == 1, "Batched op should be found" + batched_op = batched_ops[0] + assert iree_codegen.isa_scaled_contraction_op( + batched_op + ), "Batched scaled matmul should be recognized" + + dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op) + assert ( + dims_batched is not None + ), "Batch dimension must be present in batched scaled matmul" + assert dims_batched.batch == [0], f"Got {dims_batched.batch}" + assert dims_batched.m == [1], f"Got {dims_batched.m}" + assert dims_batched.n == [2], f"Got {dims_batched.n}" + assert dims_batched.k == [3], f"Got {dims_batched.k}" + assert dims_batched.kB == [4], f"Got {dims_batched.kB}"