Skip to content

Commit

Permalink
[mlir] expose transform dialect symbol merge to python (#87690)
Browse files Browse the repository at this point in the history
This functionality is available in C++, make it available in Python
directly to operate on transform modules.
  • Loading branch information
ftynse authored Apr 17, 2024
1 parent 971ec1f commit 73140da
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 2 deletions.
12 changes: 11 additions & 1 deletion mlir/include/mlir-c/Dialect/Transform/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);

//----------------------------------------------------------------------------//
// Transform interpreter.
// Transform interpreter and utilities.
//----------------------------------------------------------------------------//

/// Applies the transformation script starting at the given transform root
Expand All @@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
MlirOperation payload, MlirOperation transformRoot,
MlirOperation transformModule, MlirTransformOptions transformOptions);

/// Merge the symbols from `other` into `target`, potentially renaming them to
/// avoid conflicts. Private symbols may be renamed during the merge, public
/// symbols must have at most one declaration. A name conflict in public symbols
/// is reported as an error before returning a failure.
///
/// Note that this clones the `other` operation unlike the C++ counterpart that
/// takes ownership.
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);

#ifdef __cplusplus
}
#endif
15 changes: 15 additions & 0 deletions mlir/lib/Bindings/Python/TransformInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
py::arg("payload_root"), py::arg("transform_root"),
py::arg("transform_module"),
py::arg("transform_options") = PyMlirTransformOptions());

m.def(
"copy_symbols_and_merge_into",
[](MlirOperation target, MlirOperation other) {
mlir::python::CollectDiagnosticsToStringScope scope(
mlirOperationGetContext(target));

MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
if (mlirLogicalResultIsFailure(result)) {
throw py::value_error(
"Failed to merge symbols.\nDiagnostic message " +
scope.takeMessage());
}
},
py::arg("target"), py::arg("other"));
}

PYBIND11_MODULE(_mlirTransformInterpreter, m) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

Expand Down Expand Up @@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
unwrap(payload), unwrap(transformRoot),
cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
}

MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
MlirOperation other) {
OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
LogicalResult result = transform::detail::mergeSymbolsInto(
unwrap(target), std::move(otherOwning));
return wrap(result);
}
}
10 changes: 9 additions & 1 deletion mlir/python/mlir/dialects/transform/interpreter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ....ir import Operation
from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter


TransformOptions = _cextTransformInterpreter.TransformOptions


Expand All @@ -31,3 +30,12 @@ def apply_named_sequence(
_cextTransformInterpreter.apply_named_sequence(*args)
else:
_cextTransformInterpreter(*args, transform_options)


def copy_symbols_and_merge_into(target, other):
"""Copies symbols from other into target, renaming private symbols to avoid
duplicates. Raises an error if copying would lead to duplicate public
symbols."""
_cextTransformInterpreter.copy_symbols_and_merge_into(
_unpack_operation(target), _unpack_operation(other)
)
76 changes: 76 additions & 0 deletions mlir/test/python/dialects/transform_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,79 @@ def failed():
assert (
"must implement TransformOpInterface to be used as transform root" in str(e)
)


print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.include @callee2 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}"""

callee2_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
transform.include @callee1 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}
"""

callee1_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}
"""


@test_in_context
def include():
main = ir.Module.parse(print_root_via_include_module)
callee1 = ir.Module.parse(callee1_definition)
callee2 = ir.Module.parse(callee2_definition)
interp.copy_symbols_and_merge_into(main, callee1)
interp.copy_symbols_and_merge_into(main, callee2)

# CHECK: @print_root_via_include_module
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.include @callee2
#
# CHECK: transform.named_sequence @callee1
# CHECK: transform.print
#
# CHECK: transform.named_sequence @callee2
# CHECK: transform.include @callee1
interp.apply_named_sequence(main, main.body.operations[0], main)


@test_in_context
def partial_include():
main = ir.Module.parse(print_root_via_include_module)
callee2 = ir.Module.parse(callee2_definition)
interp.copy_symbols_and_merge_into(main, callee2)

try:
interp.apply_named_sequence(main, main.body.operations[0], main)
except ValueError as e:
assert "Failed to apply" in str(e)


@test_in_context
def repeated_include():
main = ir.Module.parse(print_root_via_include_module)
callee2 = ir.Module.parse(callee2_definition)
interp.copy_symbols_and_merge_into(main, callee2)

try:
interp.copy_symbols_and_merge_into(main, callee2)
except ValueError as e:
assert "doubly defined symbol @callee2" in str(e)

0 comments on commit 73140da

Please sign in to comment.