19 changes: 19 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
@@ -119,6 +119,25 @@ MLIR_CAPI_EXPORTED bool ireeCodegenHasIGEMMGenericConvDetails(MlirOperation op);
MLIR_CAPI_EXPORTED ireeCodegenIGEMMGenericConvDetails
ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op);

struct ireeCodegenScaledContractionDimensions {
// Batch dimension for scaled contraction (ArrayAttr).
MlirAttribute batch;
// M dimension for scaled contraction (ArrayAttr).
MlirAttribute m;
// N dimension for scaled contraction (ArrayAttr).
MlirAttribute n;
// K outer reduction dimension for scaled contraction (ArrayAttr).
MlirAttribute k;
// K blocking dimension for scaled contraction (ArrayAttr).
MlirAttribute kB;
};

MLIR_CAPI_EXPORTED bool
ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op);

MLIR_CAPI_EXPORTED ireeCodegenScaledContractionDimensions
ireeCodegenInferScaledContractionDimensions(MlirOperation op);

#ifdef __cplusplus
}
#endif
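
For reference, a minimal sketch (not part of this PR's diff) of how a C API consumer might drive the two new entry points. The helper name inspectScaledContraction is hypothetical, how the MlirOperation handle is obtained is left to the caller, and the null-attribute fallback on failure matches the implementation hunk later in this PR.

// Illustrative sketch only: classify an operation and, if it is a scaled
// contraction, fetch its inferred dimension attributes.
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/IR.h"

static void inspectScaledContraction(MlirOperation op) {
  if (!ireeCodegenMlirOperationIsAScaledContractionOp(op))
    return;
  ireeCodegenScaledContractionDimensions dims =
      ireeCodegenInferScaledContractionDimensions(op);
  // On failure the fields are left default-initialized (null attributes).
  if (mlirAttributeIsNull(dims.m))
    return;
  // dims.batch, dims.m, dims.n, dims.k, and dims.kB are ArrayAttrs of i32
  // dimension indices; see the decoding sketch at the end of this diff.
}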
41 changes: 41 additions & 0 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
@@ -770,4 +770,45 @@ NB_MODULE(_ireeCompilerDialects, m) {
"Gets IGEMM details for a linalg operation. "
"Returns None if failed to infer IGEMM convolution details.",
py::arg("linalg_op"));

//===-------------------------------------------------------------------===//
// Binding to utility function ireeCodegenMlirOperationIsAScaledContractionOp
//===-------------------------------------------------------------------===//
iree_codegen_module.def("isa_scaled_contraction_op",
&ireeCodegenMlirOperationIsAScaledContractionOp,
"Checks if the given operation is an IREE LinalgExt "
"scaled contraction op.",
py::arg("op"));

//===-------------------------------------------------------------------===//
// Binding to struct ireeCodegenScaledContractionDimensions
//===-------------------------------------------------------------------===//
py::class_<ireeCodegenScaledContractionDimensions>(
iree_codegen_module, "ScaledContractionDimensions")
.def_prop_ro("batch",
[](const ireeCodegenScaledContractionDimensions &self) {
return getIntArrayAttrValues(self.batch);
})
.def_prop_ro("m",
[](const ireeCodegenScaledContractionDimensions &self) {
return getIntArrayAttrValues(self.m);
})
.def_prop_ro("n",
[](const ireeCodegenScaledContractionDimensions &self) {
return getIntArrayAttrValues(self.n);
})
.def_prop_ro("k",
[](const ireeCodegenScaledContractionDimensions &self) {
return getIntArrayAttrValues(self.k);
})
.def_prop_ro("kB",
[](const ireeCodegenScaledContractionDimensions &self) {
return getIntArrayAttrValues(self.kB);
});

iree_codegen_module.def(
"infer_scaled_contraction_dimensions",
&ireeCodegenInferScaledContractionDimensions,
"Infers the scaled contraction dimensions for a given operation.",
py::arg("op"));
}
241 changes: 225 additions & 16 deletions compiler/bindings/python/test/api/tuner_api_test.py
@@ -77,18 +77,10 @@ def root_op():
def attention_op_detail():
dim_exprs = [affine.AffineDimExpr.get(i) for i in range(5)]

q_map = affine.AffineMap.get(
5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[2]]
) # (d0, d1, d2).
k_map = affine.AffineMap.get(
5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[2]]
) # (d0, d3, d2).
v_map = affine.AffineMap.get(
5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]]
) # (d0, d3, d4). # ()
o_map = affine.AffineMap.get(
5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]]
) # (d0, d1, d4).
q_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[2]])
k_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[2]])
v_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]])
o_map = affine.AffineMap.get(5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]])

result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)

@@ -102,10 +94,10 @@ def attention_op_detail():
dim_exprs = [affine.AffineDimExpr.get(i) for i in range(4)]

# Input affine maps that do not follow the expected pattern for an attention operation.
q_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) # (d0, d1).
k_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[2]]) # (d0, d2).
v_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[3]]) # (d0, d3).
o_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]]) # (d0, d1).
q_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]])
k_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[2]])
v_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[3]])
o_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]])

result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)
assert result.domain_rank == 4
@@ -336,3 +328,220 @@ def test_igemm_conv_details():

details = iree_codegen.get_igemm_generic_conv_details(matmul_op)
assert details is None, "IGEMM details should be None for non-conv operation"


@run
def test_isa_scaled_contraction_op():
# Test 1: Regular matmul is not a scaled contraction.
module_str = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
input_module = ir.Module.parse(module_str)
assert input_module is not None, "Failed to parse input MLIR module"
root_op_list = iree_codegen.get_tuner_root_ops(input_module)
assert len(root_op_list) == 1
matmul_op = root_op_list[0]

assert not iree_codegen.isa_scaled_contraction_op(
matmul_op
), "Regular matmul should not be a scaled contraction"

# Test 2: Fill op is not a scaled contraction.
module_str = """
module {
func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill { root_op } ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
input_module = ir.Module.parse(module_str)
root_op_list = iree_codegen.get_tuner_root_ops(input_module)
assert len(root_op_list) == 1
fill_op = root_op_list[0]

assert not iree_codegen.isa_scaled_contraction_op(
fill_op
), "Fill op should not be a scaled contraction"

# Test 3: Scaled matmul as linalg.generic should be detected.
# Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output),
# and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb).
# Uses f4E2M1FN for operands and f8E8M0FNU for scales (matching real scaled matmul pattern).
module_str = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
input_module = ir.Module.parse(module_str)
root_op_list = iree_codegen.get_tuner_root_ops(input_module)
assert len(root_op_list) == 1, "Should have one root op"

scaled_generic_op = root_op_list[0]
is_scaled = iree_codegen.isa_scaled_contraction_op(scaled_generic_op)
assert (
is_scaled
), "linalg.generic with scaled matmul pattern should be detected as scaled contraction"

dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op)
assert dims is not None, "Should be able to infer dimensions for scaled contraction"

assert dims.m == [0], f"Got {dims.m}"
assert dims.n == [1], f"Got {dims.n}"
assert dims.k == [2], f"Got {dims.k}"
assert dims.kB == [3], f"Got {dims.kB}"
assert dims.batch == [], f"Got {dims.batch}"


@run
def test_infer_scaled_contraction_dimensions():
# Test 1: Verify dimension inference on a scaled matmul operation.
module_str = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
input_module = ir.Module.parse(module_str)
root_op_list = iree_codegen.get_tuner_root_ops(input_module)
assert len(root_op_list) == 1, "Should have exactly one root op"
scaled_op = root_op_list[0]

assert iree_codegen.isa_scaled_contraction_op(
scaled_op
), "Operation should be recognized as scaled contraction"

dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op)
assert dims is not None, "Should successfully infer dimensions"
assert dims.m == [0], f"Got {dims.m}"
assert dims.n == [1], f"Got {dims.n}"
assert dims.k == [2], f"Got {dims.k}"
assert dims.kB == [3], f"Got {dims.kB}"
assert dims.batch == [], f"Got {dims.batch}"

# Test 2: Non-scaled contraction should return None.
module_str_regular = """
module {
func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
input_module_regular = ir.Module.parse(module_str_regular)
regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular)
assert len(regular_ops) == 1
regular_matmul = regular_ops[0]

# Regular matmul should not have scaled contraction dimensions.
# Check if all dimensions are empty (indicating it's not a scaled contraction).
dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul)
if dims_regular is not None:
all_empty = (
len(dims_regular.m) == 0
and len(dims_regular.n) == 0
and len(dims_regular.k) == 0
and len(dims_regular.kB) == 0
and len(dims_regular.batch) == 0
)
assert (
all_empty
), "Regular matmul should not have valid scaled contraction dimensions"

# Test 3: Batched scaled matmul.
module_str_batched = """
module {
func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>,
%lhs_scales: tensor<8x16x4xf8E8M0FNU>, %rhs_scales: tensor<8x16x4xf8E8M0FNU>,
%out: tensor<8x16x16xf32>) -> tensor<8x16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4xf8E8M0FNU>, tensor<8x16x4xf8E8M0FNU>)
outs(%out : tensor<8x16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<8x16x16xf32>
return %result : tensor<8x16x16xf32>
}
}
"""
input_module_batched = ir.Module.parse(module_str_batched)
batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched)
assert len(batched_ops) == 1, "Batched op should be found"
batched_op = batched_ops[0]
assert iree_codegen.isa_scaled_contraction_op(
batched_op
), "Batched scaled matmul should be recognized"

dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op)
assert (
dims_batched is not None
), "Batch dimension must be present in batched scaled matmul"
assert dims_batched.batch == [0], f"Got {dims_batched.batch}"
assert dims_batched.m == [1], f"Got {dims_batched.m}"
assert dims_batched.n == [2], f"Got {dims_batched.n}"
assert dims_batched.k == [3], f"Got {dims_batched.k}"
assert dims_batched.kB == [4], f"Got {dims_batched.kB}"
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
@@ -320,3 +321,46 @@ ireeCodegenGetIGEMMGenericConvDetails(MlirOperation op) {

return result;
}

bool ireeCodegenMlirOperationIsAScaledContractionOp(MlirOperation op) {
  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
  return linalgOp &&
         mlir::iree_compiler::IREE::LinalgExt::isaScaledContractionOpInterface(
             linalgOp);
}

ireeCodegenScaledContractionDimensions
ireeCodegenInferScaledContractionDimensions(MlirOperation op) {
ireeCodegenScaledContractionDimensions result{};
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp) {
return result;
}

llvm::FailureOr<
mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions>
maybeDims =
mlir::iree_compiler::IREE::LinalgExt::inferScaledContractionDims(
linalgOp);
if (failed(maybeDims)) {
return result;
}

const mlir::iree_compiler::IREE::LinalgExt::ScaledContractionDimensions
&scaledContractionDims = *maybeDims;
mlir::MLIRContext *ctx = linalgOp.getContext();
mlir::Builder b(ctx);
auto toAttr = [&b](llvm::ArrayRef<unsigned> vals) -> MlirAttribute {
llvm::SmallVector<mlir::Attribute, 2> attrs =
llvm::map_to_vector(vals, [&b](unsigned val) -> mlir::Attribute {
return b.getI32IntegerAttr(val);
});
return wrap(b.getArrayAttr(attrs));
};

result.batch = toAttr(scaledContractionDims.batch);
result.m = toAttr(scaledContractionDims.m);
result.n = toAttr(scaledContractionDims.n);
result.k = toAttr(scaledContractionDims.k);
result.kB = toAttr(scaledContractionDims.kB);
return result;
}
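
On the C side the returned fields are plain ArrayAttrs, so a consumer has to unpack them with the MLIR C API. Below is a rough sketch of that unpacking (the C counterpart of the getIntArrayAttrValues helper used by the Python bindings above); unpackDimIndices is a hypothetical helper name and the attribute is assumed to have already been checked for null.

// Illustrative sketch only: convert one dimension field (an ArrayAttr of i32
// integer attributes) into a plain array of indices. Assumes attr is non-null
// and out has room for mlirArrayAttrGetNumElements(attr) entries.
#include <stdint.h>
#include "mlir-c/BuiltinAttributes.h"

static intptr_t unpackDimIndices(MlirAttribute attr, int64_t *out) {
  intptr_t numElements = mlirArrayAttrGetNumElements(attr);
  for (intptr_t i = 0; i < numElements; ++i) {
    MlirAttribute element = mlirArrayAttrGetElement(attr, i);
    out[i] = mlirIntegerAttrGetValueInt(element);
  }
  return numElements;
}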