Skip to content

Commit ab09107

Browse files
committed
Working on engine
Signed-off-by: Justin Chu <[email protected]>
1 parent 5a34891 commit ab09107

File tree

3 files changed

+31
-36
lines changed

3 files changed

+31
-36
lines changed

src/onnx_ir/_shape_type_inference/_common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,21 @@ class NodeInferrer(abc.ABC):
7676
This class provides a common interface for all node inferrers.
7777
"""
7878

79-
def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None:
79+
def __init__(
80+
self, op_type: str, opsets: Collection[int], domain: str = "", overload: str = ""
81+
) -> None:
8082
"""Initialize the node inferrer.
8183
8284
Args:
8385
op_type: The type of the operation.
8486
opsets: A collection of ONNX opset versions supported by this inferrer.
8587
domain: The domain of the operation, default is an empty string.
88+
overload: The overload identifier for the operation, default is an empty string.
8689
"""
8790
self.op_type = op_type
8891
self.opsets = opsets
8992
self.domain = domain
93+
self.overload = overload
9094

9195
def __repr__(self) -> str:
9296
"""Return a string representation of the node inferrer."""

src/onnx_ir/_shape_type_inference/_engine.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import enum
66
import logging
7-
from collections.abc import Sequence
7+
from collections.abc import Iterable, Sequence
88

99
import onnx_ir as ir
1010
from onnx_ir._shape_type_inference import _common
@@ -30,7 +30,7 @@ class SymbolicInferenceEngine:
3030

3131
def __init__(
3232
self,
33-
node_inferrers: Sequence[_common.NodeInferrer],
33+
node_inferrers: Iterable[_common.NodeInferrer],
3434
reconciliation_policy: str = "reconcile",
3535
) -> None:
3636
"""Initialize the symbolic inference engine.
@@ -40,16 +40,12 @@ def __init__(
4040
reconciliation_policy: Policy for handling conflicts between inferred and existing values.
4141
"""
4242
self.reconciliation_policy = ReconciliationPolicy(reconciliation_policy)
43-
self._inferrer_registry: dict[tuple[str, str], list[_common.NodeInferrer]] = {}
43+
self._inferrer_registry: dict[ir.OperatorIdentifier, list[_common.NodeInferrer]] = {}
4444

4545
# Register inferrers by (op_type, domain)
4646
for inferrer in node_inferrers:
47-
key = (inferrer.op_type, inferrer.domain)
48-
if key not in self._inferrer_registry:
49-
self._inferrer_registry[key] = []
50-
self._inferrer_registry[key].append(inferrer)
51-
52-
logger.info("Initialized inference engine with %s inferrers", len(node_inferrers))
47+
key = (inferrer.domain, inferrer.op_type, inferrer.overload)
48+
self._inferrer_registry.setdefault(key, []).append(inferrer)
5349

5450
def infer_model(self, model: ir.Model) -> None:
5551
"""Perform shape and type inference on an entire model.
@@ -60,10 +56,10 @@ def infer_model(self, model: ir.Model) -> None:
6056
Raises:
6157
InferenceError: If inference fails for any node.
6258
"""
63-
logger.info("Starting inference on model with %s nodes", len(model.graph.nodes))
59+
logger.info("Starting inference on model with %s nodes", len(model.graph))
6460

6561
# Process nodes in topological order
66-
for i, node in enumerate(model.graph.nodes):
62+
for i, node in enumerate(model.graph):
6763
try:
6864
self._infer_node(node, model)
6965
logger.debug("Successfully inferred node %s: %s", i, node.op_type)
@@ -123,14 +119,24 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer
123119
Returns:
124120
The best matching inferrer, or None if no suitable inferrer is found.
125121
"""
126-
key = (node.op_type, node.domain)
122+
key = (node.domain, node.op_type, node.overload)
127123
inferrers = self._inferrer_registry.get(key, [])
128124

129125
if not inferrers:
130126
return None
131127

132128
# Get model opset version for this domain
133-
opset_version = self._get_opset_version(model, node.domain)
129+
if node.version is not None:
130+
opset_version = node.version
131+
elif node.graph is not None and node.domain in node.graph.opset_imports:
132+
opset_version = node.graph.opset_imports[node.domain]
133+
else:
134+
# Fallback to model-level opset import
135+
if node.domain not in model.opset_imports:
136+
raise InferenceError(
137+
f"No opset import found for domain '{node.domain}' in model"
138+
)
139+
opset_version = model.opset_imports[node.domain]
134140

135141
# Find inferrers that support this opset version
136142
suitable_inferrers = [
@@ -140,31 +146,15 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer
140146
if not suitable_inferrers:
141147
logger.warning(
142148
"No inferrer supports opset %s for %s (domain: %s)",
143-
opset_version, node.op_type, node.domain
149+
opset_version,
150+
node.op_type,
151+
node.domain,
144152
)
145153
return None
146154

147155
# Return the first suitable inferrer (could be enhanced with priority logic)
148156
return suitable_inferrers[0]
149157

150-
def _get_opset_version(self, model: ir.Model, domain: str) -> int:
151-
"""Get the opset version for a given domain in the model.
152-
153-
Args:
154-
model: The model to check.
155-
domain: The domain to get the opset version for.
156-
157-
Returns:
158-
The opset version for the domain.
159-
"""
160-
# Look for opset import for this domain
161-
for opset_import in model.opset_imports:
162-
if opset_import.domain == domain:
163-
return opset_import.version
164-
165-
# Default to a high version if not found
166-
return 999
167-
168158
def _reconcile_outputs(self, node: ir.Node, inferred_values: Sequence[ir.Value]) -> None:
169159
"""Reconcile inferred output values with existing node outputs.
170160
@@ -251,8 +241,7 @@ def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape:
251241
"""
252242
if len(shape1) != len(shape2):
253243
logger.warning(
254-
"Shape rank mismatch: %s vs %s. Using first shape.",
255-
len(shape1), len(shape2)
244+
"Shape rank mismatch: %s vs %s. Using first shape.", len(shape1), len(shape2)
256245
)
257246
return shape1
258247

src/onnx_ir/_shape_type_inference/ops/matmul.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult:
5151
rhs_batch = rhs_shape[:-2]
5252
if lhs_batch and rhs_batch:
5353
# TODO(justinchuby): Ensure this is correct
54-
batch_shape = broadcast_shapes_bidirectional(ir.Shape(lhs_batch), ir.Shape(rhs_batch))
54+
batch_shape = broadcast_shapes_bidirectional(
55+
ir.Shape(lhs_batch), ir.Shape(rhs_batch)
56+
)
5557
output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]
5658
output_shape = ir.Shape(output_dims)
5759
elif lhs_batch:

0 commit comments

Comments
 (0)