diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 0a179af852..752a52a243 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -18,6 +18,7 @@ class Tape(Iterable[ir.Node]): def __init__(self) -> None: self._nodes: list[ir.Node] = [] + self._initializers: list[ir.Value] = [] def __iter__(self) -> Iterator[ir.Node]: return iter(self._nodes) @@ -26,6 +27,10 @@ def __iter__(self) -> Iterator[ir.Node]: def nodes(self) -> Sequence[ir.Node]: return tuple(self._nodes) + @property + def initializers(self) -> Sequence[ir.Value]: + return tuple(self._initializers) + def op( self, op_type: str, @@ -60,6 +65,17 @@ def op_multi_output( return node.outputs + def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: + name = name or tensor.name + if name is None: + raise ValueError("Name must be provided for initializer.") + shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims) + value = ir.Value( + name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor + ) + self._initializers.append(value) + return value + # A type representing the domains/versions used in creating nodes in IR. 
UsedOpsets = List[Tuple[str, Optional[int]]] diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 8b4dbbfe55..deb1be9e9e 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -119,7 +119,11 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: evaluator = self.get_evaluator(domain, op, version) if evaluator is None: return None - return evaluator(*args, **kwargs) + try: + return evaluator(*args, **kwargs) + except Exception as e: + logger.warning("Evaluation failed: %s", e) + return None _reference_evaluator = ReferenceEvaluator() diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 84ac42beb2..868da62443 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -900,6 +900,7 @@ class ReplacementSubgraph: match: MatchResult new_outputs: Sequence[ir.Value] new_nodes: Sequence[ir.Node] + new_initializers: Sequence[ir.Value] used_opsets: _tape.UsedOpsets @@ -928,7 +929,9 @@ def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: return None # Failed to create replacement subgraph if not isinstance(new_outputs, Sequence): new_outputs = [new_outputs] - return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets) + return ReplacementSubgraph( + match, new_outputs, context.nodes, context.initializers, context.used_opsets + ) def _update_opset_imports( @@ -1566,6 +1569,23 @@ def _apply_to_graph_or_function( if delta is None or tracer is not None: continue assert isinstance(delta, ReplacementSubgraph) + if delta.new_initializers: + if isinstance(graph_or_function, ir.Function): + # TODO(rama): Can't add initializers to functions. But currently this is not + # an issue, as we apply inlining before applying rewrite rules. 
+ if verbose: + print( + f"Rewrites adding initializers not supported for functions: {rule}" + ) + continue + initializers = graph_or_function.initializers + # Name conflict: skip this rewrite rather than silently overwriting an existing initializer. + if any(init.name in initializers for init in delta.new_initializers): + if verbose: + print(f"Initializer name conflict; skipping rewrite: {rule}") + continue + for initializer in delta.new_initializers: + initializers[initializer.name] = initializer # type: ignore[index] # TODO: This does not yet handle the problem of determining the correct insertion point # for inserted nodes in the case of patterns with multiple output-nodes. The following # is sufficient for patterns with a single output-node "node", which can serve as the diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 1803ab6706..ca865ecde1 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -5,6 +5,7 @@ import logging import unittest +import numpy as np import onnx.checker import onnx.parser @@ -543,6 +544,39 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: # Not a robust test. But test serves to ensure that debug mode is producing something.
self.assertIn("OpType mismatch: expected Abs, got Neg", captured_output) + def test_new_initializer(self): + def source_pattern(op, x, y): + return op.Gemm(x, op.Transpose(y)) + + def check(context, x, y): + return y.const_value is not None + + def replacement(op, x, y): + tensor = y.const_value + name = y.name + "_transposed" + transposed = ir.tensor(tensor.numpy().T, name=name) + initializer = op.initializer(transposed) + return op.Gemm(x, initializer) + + rule = pattern.RewriteRule(source_pattern, replacement, check) + + y_value = np.random.rand(8, 4).astype(np.float32) + + @script() + def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]: + y = op.Constant(value=y_value) + return op.Gemm(x, op.Transpose(y)) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.graph.initializers), 1) + last_node = model.graph[-1] + self.assertEqual(len(last_node.inputs), 2) + init_name = last_node.inputs[1].name + self.assertIn(init_name, model.graph.initializers) + self.assertIs(last_node.inputs[1], model.graph.initializers[init_name]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):