I would like to know whether it is possible to write pattern functions that update an ONNX model by adding new nodes without removing the existing ones. Here is a simple example of what I would like to achieve:
```python
import onnx
from onnxscript.rewriter import pattern, rewrite


def _target_pattern(op, x):
    return op.Relu(x, _outputs=['relu'])


def _replacement_pattern(op, relu, **_):
    return op.Identity(relu)


def _validate_pattern(op, relu, **_):
    return not any(outbound.op_type in ["Identity"] for outbound in relu.consumers())


# Define model
model = onnx.parser.parse_model(
    """< ir_version: 8, opset_import: ["" : 15] >
    test_model (float[N, 3, 112, 112] X) => (float [N, ?, ?, ?] Y){
        A = Relu(X)
        Y = Relu(A)
    }"""
)

# Define transformation rules
rules = [
    pattern.RewriteRule(_target_pattern, _replacement_pattern, _validate_pattern,
                        remove_nodes=False)
]

# Apply rewrites
new_model = rewrite(model, pattern_rewrite_rules=rules)

# Expected sequence
assert [node.op_type for node in new_model.graph.node] == ["Relu", "Identity", "Relu", "Identity"]
```

I am trying to use the output of `_target_pattern` in `_replacement_pattern` to add the node, but I can't seem to stop the nodes from being removed.
Any idea?
leshabirukov