44
55import enum
66import logging
7- from collections .abc import Sequence
7+ from collections .abc import Iterable , Sequence
88
99import onnx_ir as ir
1010from onnx_ir ._shape_type_inference import _common
@@ -30,7 +30,7 @@ class SymbolicInferenceEngine:
3030
3131 def __init__ (
3232 self ,
33- node_inferrers : Sequence [_common .NodeInferrer ],
33+ node_inferrers : Iterable [_common .NodeInferrer ],
3434 reconciliation_policy : str = "reconcile" ,
3535 ) -> None :
3636 """Initialize the symbolic inference engine.
@@ -40,16 +40,12 @@ def __init__(
4040 reconciliation_policy: Policy for handling conflicts between inferred and existing values.
4141 """
4242 self .reconciliation_policy = ReconciliationPolicy (reconciliation_policy )
43- self ._inferrer_registry : dict [tuple [ str , str ] , list [_common .NodeInferrer ]] = {}
43+ self ._inferrer_registry : dict [ir . OperatorIdentifier , list [_common .NodeInferrer ]] = {}
4444
4545 # Register inferrers by (op_type, domain)
4646 for inferrer in node_inferrers :
47- key = (inferrer .op_type , inferrer .domain )
48- if key not in self ._inferrer_registry :
49- self ._inferrer_registry [key ] = []
50- self ._inferrer_registry [key ].append (inferrer )
51-
52- logger .info ("Initialized inference engine with %s inferrers" , len (node_inferrers ))
47+ key = (inferrer .domain , inferrer .op_type , inferrer .overload )
48+ self ._inferrer_registry .setdefault (key , []).append (inferrer )
5349
5450 def infer_model (self , model : ir .Model ) -> None :
5551 """Perform shape and type inference on an entire model.
@@ -60,10 +56,10 @@ def infer_model(self, model: ir.Model) -> None:
6056 Raises:
6157 InferenceError: If inference fails for any node.
6258 """
63- logger .info ("Starting inference on model with %s nodes" , len (model .graph . nodes ))
59+ logger .info ("Starting inference on model with %s nodes" , len (model .graph ))
6460
6561 # Process nodes in topological order
66- for i , node in enumerate (model .graph . nodes ):
62+ for i , node in enumerate (model .graph ):
6763 try :
6864 self ._infer_node (node , model )
6965 logger .debug ("Successfully inferred node %s: %s" , i , node .op_type )
@@ -123,14 +119,24 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer
123119 Returns:
124120 The best matching inferrer, or None if no suitable inferrer is found.
125121 """
126- key = (node .op_type , node .domain )
122+ key = (node .domain , node . op_type , node .overload )
127123 inferrers = self ._inferrer_registry .get (key , [])
128124
129125 if not inferrers :
130126 return None
131127
132128 # Get model opset version for this domain
133- opset_version = self ._get_opset_version (model , node .domain )
129+ if node .version is not None :
130+ opset_version = node .version
131+ elif node .graph is not None and node .domain in node .graph .opset_imports :
132+ opset_version = node .graph .opset_imports [node .domain ]
133+ else :
134+ # Fallback to model-level opset import
135+ if node .domain not in model .opset_imports :
136+ raise InferenceError (
137+ f"No opset import found for domain '{ node .domain } ' in model"
138+ )
139+ opset_version = model .opset_imports [node .domain ]
134140
135141 # Find inferrers that support this opset version
136142 suitable_inferrers = [
@@ -140,31 +146,15 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer
140146 if not suitable_inferrers :
141147 logger .warning (
142148 "No inferrer supports opset %s for %s (domain: %s)" ,
143- opset_version , node .op_type , node .domain
149+ opset_version ,
150+ node .op_type ,
151+ node .domain ,
144152 )
145153 return None
146154
147155 # Return the first suitable inferrer (could be enhanced with priority logic)
148156 return suitable_inferrers [0 ]
149157
150- def _get_opset_version (self , model : ir .Model , domain : str ) -> int :
151- """Get the opset version for a given domain in the model.
152-
153- Args:
154- model: The model to check.
155- domain: The domain to get the opset version for.
156-
157- Returns:
158- The opset version for the domain.
159- """
160- # Look for opset import for this domain
161- for opset_import in model .opset_imports :
162- if opset_import .domain == domain :
163- return opset_import .version
164-
165- # Default to a high version if not found
166- return 999
167-
168158 def _reconcile_outputs (self , node : ir .Node , inferred_values : Sequence [ir .Value ]) -> None :
169159 """Reconcile inferred output values with existing node outputs.
170160
@@ -251,8 +241,7 @@ def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape:
251241 """
252242 if len (shape1 ) != len (shape2 ):
253243 logger .warning (
254- "Shape rank mismatch: %s vs %s. Using first shape." ,
255- len (shape1 ), len (shape2 )
244+ "Shape rank mismatch: %s vs %s. Using first shape." , len (shape1 ), len (shape2 )
256245 )
257246 return shape1
258247
0 commit comments