16 changes: 16 additions & 0 deletions onnxscript/ir/_tape.py
@@ -18,6 +18,7 @@

    def __init__(self) -> None:
        self._nodes: list[ir.Node] = []
+        self._initializers: list[ir.Value] = []

    def __iter__(self) -> Iterator[ir.Node]:
        return iter(self._nodes)
@@ -26,6 +27,10 @@
    def nodes(self) -> Sequence[ir.Node]:
        return tuple(self._nodes)

+    @property
+    def initializers(self) -> Sequence[ir.Value]:
+        return tuple(self._initializers)
+
    def op(
        self,
        op_type: str,
@@ -60,6 +65,17 @@

        return node.outputs

+    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
+        name = name or tensor.name
+        if name is None:
+            raise ValueError("Name must be provided for initializer.")
+        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
+        value = ir.Value(
+            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
+        )
+        self._initializers.append(value)
+        return value


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = List[Tuple[str, Optional[int]]]
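For illustration (not part of the diff), a minimal sketch of the new hook, assuming the class above is the tape builder exposed as onnxscript.ir._tape.Tape:

import numpy as np
from onnxscript import ir
from onnxscript.ir import _tape

tape = _tape.Tape()
# Wrap a concrete tensor in an ir.Value and record it on the tape.
weight = ir.tensor(np.ones((4, 4), dtype=np.float32), name="weight")
w = tape.initializer(weight)
assert w.const_value is weight    # the Value is backed by the tensor
assert tape.initializers == (w,)  # exposed via the new property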
6 changes: 5 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
@@ -119,7 +119,11 @@
        evaluator = self.get_evaluator(domain, op, version)
        if evaluator is None:
            return None
-        return evaluator(*args, **kwargs)
+        try:
+            return evaluator(*args, **kwargs)
+        except Exception as e:
+            logger.warning("Evaluation failed: %s", e)
+            return None



_reference_evaluator = ReferenceEvaluator()
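Read in isolation, the guard amounts to the following pattern (a standalone sketch; try_evaluate is a hypothetical name): a failing evaluator now means "cannot constant-fold this node" rather than an optimizer crash.

import logging

logger = logging.getLogger(__name__)

def try_evaluate(evaluator, *args, **kwargs):
    # Any evaluator failure is logged and mapped to None, the same
    # sentinel used when no evaluator is registered for the op.
    try:
        return evaluator(*args, **kwargs)
    except Exception as e:
        logger.warning("Evaluation failed: %s", e)
        return None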
22 changes: 21 additions & 1 deletion onnxscript/rewriter/pattern.py
@@ -900,6 +900,7 @@
    match: MatchResult
    new_outputs: Sequence[ir.Value]
    new_nodes: Sequence[ir.Node]
+    new_initializers: Sequence[ir.Value]
    used_opsets: _tape.UsedOpsets


@@ -928,7 +929,9 @@
            return None  # Failed to create replacement subgraph
        if not isinstance(new_outputs, Sequence):
            new_outputs = [new_outputs]
-        return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets)
+        return ReplacementSubgraph(
+            match, new_outputs, context.nodes, context.initializers, context.used_opsets
+        )


def _update_opset_imports(
@@ -1566,6 +1569,23 @@
                if delta is None or tracer is not None:
                    continue
                assert isinstance(delta, ReplacementSubgraph)
+                if delta.new_initializers:
+                    if isinstance(graph_or_function, ir.Function):
+                        # TODO(rama): Can't add initializers to functions. But currently this is not
+                        # an issue, as we apply inlining before applying rewrite rules.
+                        if verbose:
+                            print(
f"Rewrites adding initializers not supported for functions: {rule}"
)
continue

+                    initializers = graph_or_function.initializers
+                    for initializer in delta.new_initializers:
+                        if initializer.name in initializers:
+                            if verbose:
+                                print(f"Initializer {initializer.name} already exists.")
+                            continue
+                    for initializer in delta.new_initializers:
+                        initializers[initializer.name] = initializer  # type: ignore[index]
                # TODO: This does not yet handle the problem of determining the correct insertion point
                # for inserted nodes in the case of patterns with multiple output-nodes. The following
                # is sufficient for patterns with a single output-node "node", which can serve as the
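Taken together with the _tape.py change, this lets a replacement function emit initializers through the RewriterContext. A sketch of a hypothetical rule (div_pattern, div_check, and div_replacement are illustrative names, not part of this PR) that folds division by a constant into multiplication by a precomputed reciprocal initializer:

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import pattern

def div_pattern(op, x, c):
    return op.Div(x, c)

def div_check(context, x, c):
    # Only rewrite when the divisor is a known constant.
    return c.const_value is not None

def div_replacement(op, x, c):
    # Precompute 1/c and register it as a new initializer via the context.
    recip = ir.tensor(np.reciprocal(c.const_value.numpy()), name=f"{c.name}_recip")
    return op.Mul(x, op.initializer(recip))

rule = pattern.RewriteRule(div_pattern, div_replacement, div_check)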
34 changes: 34 additions & 0 deletions onnxscript/rewriter/pattern_test.py
@@ -5,6 +5,7 @@
import logging
import unittest

+import numpy as np
import onnx.checker
import onnx.parser

@@ -543,6 +544,39 @@
        # Not a robust test. But test serves to ensure that debug mode is producing something.
        self.assertIn("OpType mismatch: expected Abs, got Neg", captured_output)

+    def test_new_initializer(self):
+        def source_pattern(op, x, y):
+            return op.Gemm(x, op.Transpose(y))
+
+        def check(context, x, y):
+            return y.const_value is not None
+
+        def replacement(op, x, y):
+            tensor = y.const_value
+            name = y.name + "_transposed"
+            transposed = ir.tensor(tensor.numpy().T, name=name)
+            initializer = op.initializer(transposed)
+            return op.Gemm(x, initializer)
+
+        rule = pattern.RewriteRule(source_pattern, replacement, check)
+
+        y_value = np.random.rand(8, 4).astype(np.float32)
+
+        @script()
+        def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]:
+            y = op.Constant(value=y_value)
+            return op.Gemm(x, op.Transpose(y))

+        model_proto = test_model.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        rule.apply_to_model(model)
+        self.assertEqual(len(model.graph.initializers), 1)
+        last_node = model.graph[-1]
+        self.assertEqual(len(last_node.inputs), 2)
+        init_name = last_node.inputs[1].name
+        self.assertIn(init_name, model.graph.initializers)
+        self.assertIs(last_node.inputs[1], model.graph.initializers[init_name])


class PatternBuilderTest(unittest.TestCase):
    def test_pattern_builder_context(self):