
How to use the rewriter to append new nodes? #2064

@Johansmm


I would like to know whether it is possible to write pattern functions that update an ONNX model by adding new nodes without removing the existing ones. Here is a simple example of what I would like to achieve:

```python
import onnx
from onnxscript.rewriter import pattern, rewrite


def _target_pattern(op, x):
    return op.Relu(x, _outputs=['relu'])


def _replacement_pattern(op, relu, **_):
    return op.Identity(relu)


def _validate_pattern(op, relu, **_):
    return not any(outbound.op_type in ["Identity"] for outbound in relu.consumers())


# Define model
model = onnx.parser.parse_model(
    """< ir_version: 8, opset_import: ["" : 15] >
    test_model (float[N, 3, 112, 112] X) => (float [N, ?, ?, ?] Y){
        A = Relu(X)
        Y = Relu(A)
    }"""
)

# Define transformation rules
rules = [
    pattern.RewriteRule(_target_pattern, _replacement_pattern, _validate_pattern,
                        remove_nodes=False)
]

# Apply rewrites
new_model = rewrite(model, pattern_rewrite_rules=rules)

# Expected sequence
assert [node.op_type for node in new_model.graph.node] == ["Relu", "Identity", "Relu", "Identity"]
```

I am trying to use the output bound by `_target_pattern` (`relu`) in `_replacement_pattern` to add the new Identity node, but I cannot stop the matched nodes from being removed, even with `remove_nodes=False`.

Any ideas?
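
For illustration only, here is a library-independent toy model of the rewrite described above. It assumes, as is typical of pattern rewriters (an assumption, not something confirmed for onnxscript here), that the nodes matched by the target pattern are replaced by whatever the replacement function builds. Under that assumption, keeping the Relu means rebuilding it inside the replacement, which in onnxscript terms would presumably look like `op.Identity(op.Relu(x))` rather than `op.Identity(relu)`. `Node`, `consumers` and `rewrite_relu` are hypothetical helpers, not onnxscript APIs; the Identity check mirrors the intent of `_validate_pattern`.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical toy node: not the ONNX or onnxscript IR class."""
    op_type: str
    inputs: list
    outputs: list


def consumers(nodes, value):
    """Return every node that reads `value` (same idea as relu.consumers())."""
    return [n for n in nodes if value in n.inputs]


def rewrite_relu(nodes, reemit_matched=True):
    """Toy one-pass rewrite of every Relu node.

    The matched Relu is always dropped and replaced by whatever the
    replacement builds; the last replacement node reuses the old output
    name so downstream nodes stay connected.
    """
    result = []
    for node in nodes:
        if node.op_type != "Relu":
            result.append(node)
            continue
        (y,) = node.outputs
        # Condition, analogous to _validate_pattern: leave a Relu that
        # already feeds an Identity untouched so repeated passes do not stack nodes.
        if any(c.op_type == "Identity" for c in consumers(nodes, y)):
            result.append(node)
            continue
        if reemit_matched:
            # Replacement ~ Identity(Relu(x)): rebuild the Relu, then append Identity.
            tmp = y + "_relu"
            result.append(Node("Relu", list(node.inputs), [tmp]))
            result.append(Node("Identity", [tmp], [y]))
        else:
            # Replacement ~ Identity(x): nothing re-creates the Relu, so it is lost.
            result.append(Node("Identity", list(node.inputs), [y]))
    return result


model = [Node("Relu", ["X"], ["A"]), Node("Relu", ["A"], ["Y"])]

kept = rewrite_relu(model)
print([n.op_type for n in kept])  # ['Relu', 'Identity', 'Relu', 'Identity']

dropped = rewrite_relu(model, reemit_matched=False)
print([n.op_type for n in dropped])  # ['Identity', 'Identity']

# A second pass changes nothing because of the Identity-consumer condition.
print([n.op_type for n in rewrite_relu(kept)])  # ['Relu', 'Identity', 'Relu', 'Identity']
```

In this toy model, a flag like `remove_nodes=False` would not help on its own: the final sequence depends only on what the replacement builds, which is why the Relu has to reappear in it.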
