diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h index 00095d5040a0..fa320324234e 100644 --- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h +++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h @@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions); //----------------------------------------------------------------------------// -// Transform interpreter. +// Transform interpreter and utilities. //----------------------------------------------------------------------------// /// Applies the transformation script starting at the given transform root @@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence( MlirOperation payload, MlirOperation transformRoot, MlirOperation transformModule, MlirTransformOptions transformOptions); +/// Merge the symbols from `other` into `target`, potentially renaming them to +/// avoid conflicts. Private symbols may be renamed during the merge, public +/// symbols must have at most one declaration. A name conflict in public symbols +/// is reported as an error before returning a failure. +/// +/// Note that this clones the `other` operation unlike the C++ counterpart that +/// takes ownership. +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 6517f8c39dfa..f6b4532b1b6b 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) { py::arg("payload_root"), py::arg("transform_root"), py::arg("transform_module"), py::arg("transform_options") = PyMlirTransformOptions()); + + m.def( + "copy_symbols_and_merge_into", + [](MlirOperation target, MlirOperation other) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(target)); + + MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error( + "Failed to merge symbols.\nDiagnostic message " + + scope.takeMessage()); + } + }, + py::arg("target"), py::arg("other")); } PYBIND11_MODULE(_mlirTransformInterpreter, m) { diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp index eb6951dc5584..145455e1c1b3 100644 --- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -15,6 +15,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" @@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence( unwrap(payload), unwrap(transformRoot), cast(unwrap(transformModule)), *unwrap(transformOptions))); } + +MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target, + MlirOperation other) { + OwningOpRef otherOwning(unwrap(other)->clone()); + LogicalResult result = transform::detail::mergeSymbolsInto( + unwrap(target), std::move(otherOwning)); + return wrap(result); +} } diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py index 6145b99224eb..34cdc43cb617 100644 --- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py +++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py @@ -5,7 +5,6 @@ from ....ir import Operation from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter - TransformOptions = _cextTransformInterpreter.TransformOptions @@ -31,3 +30,12 @@ def apply_named_sequence( _cextTransformInterpreter.apply_named_sequence(*args) else: _cextTransformInterpreter(*args, transform_options) + + +def copy_symbols_and_merge_into(target, other): + """Copies symbols from other into target, renaming private symbols to avoid + duplicates. Raises an error if copying would lead to duplicate public + symbols.""" + _cextTransformInterpreter.copy_symbols_and_merge_into( + _unpack_operation(target), _unpack_operation(other) + ) diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py index 740c49f76a26..807a98c49327 100644 --- a/mlir/test/python/dialects/transform_interpreter.py +++ b/mlir/test/python/dialects/transform_interpreter.py @@ -54,3 +54,79 @@ def failed(): assert ( "must implement TransformOpInterface to be used as transform root" in str(e) ) + + +print_root_via_include_module = """ +module @print_root_via_include_module attributes {transform.with_named_sequence} { + transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly}) + transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly}) + transform.named_sequence @__transform_main(%root: !transform.any_op) { + transform.include @callee2 failures(propagate) + (%root) : (!transform.any_op) -> () + transform.yield + } +}""" + +callee2_definition = """ +module attributes {transform.with_named_sequence} { + transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly}) + transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) { + transform.include @callee1 failures(propagate) + (%root) : (!transform.any_op) -> () + transform.yield + } +} +""" + +callee1_definition = """ +module attributes {transform.with_named_sequence} { + transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) { + transform.print %root { name = \"from interpreter\" }: !transform.any_op + transform.yield + } +} +""" + + +@test_in_context +def include(): + main = ir.Module.parse(print_root_via_include_module) + callee1 = ir.Module.parse(callee1_definition) + callee2 = ir.Module.parse(callee2_definition) + interp.copy_symbols_and_merge_into(main, callee1) + interp.copy_symbols_and_merge_into(main, callee2) + + # CHECK: @print_root_via_include_module + # CHECK: transform.named_sequence @__transform_main + # CHECK: transform.include @callee2 + # + # CHECK: transform.named_sequence @callee1 + # CHECK: transform.print + # + # CHECK: transform.named_sequence @callee2 + # CHECK: transform.include @callee1 + interp.apply_named_sequence(main, main.body.operations[0], main) + + +@test_in_context +def partial_include(): + main = ir.Module.parse(print_root_via_include_module) + callee2 = ir.Module.parse(callee2_definition) + interp.copy_symbols_and_merge_into(main, callee2) + + try: + interp.apply_named_sequence(main, main.body.operations[0], main) + except ValueError as e: + assert "Failed to apply" in str(e) + + +@test_in_context +def repeated_include(): + main = ir.Module.parse(print_root_via_include_module) + callee2 = ir.Module.parse(callee2_definition) + interp.copy_symbols_and_merge_into(main, callee2) + + try: + interp.copy_symbols_and_merge_into(main, callee2) + except ValueError as e: + assert "doubly defined symbol @callee2" in str(e)