Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
}

def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
[TransformOpInterface, TransformEachOpTrait,
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Applies the specified registered pass or pass pipeline";
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
set during pass/pipeline registration. Optionally, pass options may be
specified as a string attribute. The pass options syntax is identical to the
one used with "mlir-opt".
specified as a string attribute with the option to pass the attribute as a
param. The pass options syntax is identical to the one used with "mlir-opt".

This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
Expand All @@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
TransformHandleTypeInterface:$target,
StrAttr:$pass_name,
DefaultValuedAttr<StrAttr, "\"\"">:$options);
DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
$pass_name `to` $target attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
$pass_name (`with` `options` `=`
custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
`to` $target attr-dict `:` functional-type(operands, results)
}];
}

Expand Down
117 changes: 99 additions & 18 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@

using namespace mlir;

static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser,
std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
StringAttr &staticOptions);
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, Value dynamicOptions,
StringAttr staticOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
Expand Down Expand Up @@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, transform::TransformState &state) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous. Even
// more so when applying passes because they may perform a wide range of IR
// modifications.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;
void transform::ApplyRegisteredPassOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
onlyReadsHandle(getDynamicOptionsMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// Check whether pass options are specified, either as a dynamic param or
// a static attribute. In either case, options are passed as a single string.
StringRef options;
if (auto dynamicOptions = getDynamicOptions()) {
ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
if (dynamicOptionsParam.size() != 1) {
return emitSilenceableError()
<< "options passed as a param must be a single value, got "
<< dynamicOptionsParam.size();
}
if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
options = optionsStrAttr.getValue();
} else {
return emitSilenceableError()
<< "options passed as a param must be a string, got "
<< dynamicOptionsParam[0];
}
} else {
options = getStaticOptions();
}

// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
Expand All @@ -786,26 +814,79 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return emitDefiniteFailure()
<< "unknown pass or pass pipeline: " << getPassName();

// Create pass manager and run the pass or pass pipeline.
// Create pass manager and add the pass or pass pipeline.
PassManager pm(getContext());
if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
emitError(msg);
return failure();
}))) {
return emitDefiniteFailure()
<< "failed to add pass or pass pipeline to pipeline: "
<< getPassName();
}
if (failed(pm.run(target))) {
auto diag = emitSilenceableError() << "pass pipeline failed";
diag.attachNote(target->getLoc()) << "target op";
return diag;

auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
for (Operation *target : targets) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous. Even
// more so when applying passes because they may perform a wide range of IR
// modifications.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;

// Run the pass or pass pipeline on the current target operation.
if (failed(pm.run(target))) {
auto diag = emitSilenceableError() << "pass pipeline failed";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
}

results.push_back(target);
// The applied pass will have directly modified the payload IR(s).
results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}

static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser,
std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
StringAttr &staticOptions) {
dynamicOptions = std::nullopt;
OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
OptionalParseResult hasDynamicOptions =
parser.parseOptionalOperand(dynamicOptionsOperand);

if (hasDynamicOptions.has_value()) {
if (failed(hasDynamicOptions.value()))
return failure();

dynamicOptions = dynamicOptionsOperand;
return success();
}

OptionalParseResult hasStaticOptions =
parser.parseOptionalAttribute(staticOptions);
if (hasStaticOptions.has_value()) {
if (failed(hasStaticOptions.value()))
return failure();
return success();
}

return success();
}

static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, Value dynamicOptions,
StringAttr staticOptions) {
if (dynamicOptions) {
printer.printOperand(dynamicOptions);
} else if (!staticOptions.getValue().empty()) {
printer.printAttribute(staticOptions);
}
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 51 additions & 2 deletions mlir/test/Dialect/Transform/test-pass-application.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
// expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
Expand All @@ -94,7 +94,56 @@ func.func @valid_pass_option() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func @valid_dynamic_pass_option()
func.func @valid_dynamic_pass_option() {
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pass_options = transform.param.constant "top-down=false" -> !transform.any_param
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----

func.func @invalid_pass_option_param() {
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pass_options = transform.param.constant 42 -> !transform.any_param
// expected-error @below {{options passed as a param must be a string, got 42}}
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

func.func @too_many_pass_option_params() {
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%x = transform.param.constant "x" -> !transform.any_param
%pass_options = transform.merge_handles %x, %x : !transform.any_param
// expected-error @below {{options passed as a param must be a single value, got 2}}
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
}
Expand Down
Loading