From 9a744c57bb804b33edfa71792401d99e42d4b01c Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 9 Oct 2025 21:15:14 +0000 Subject: [PATCH 1/2] fix constant in constant folding --- onnxscript/optimizer/_constant_folding.py | 42 ++++++++++++----------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 8317d2be63..d2606f0b1e 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -76,7 +76,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool: def _process_constant_node(node: ir.Node) -> None: """Sets const_value of output value of a Constant op node.""" - if node.op_type != "Constant" or node.domain != "": + if not _is_onnx_op(node, "Constant"): return if len(node.attributes) != 1: return @@ -1099,24 +1099,11 @@ def process_node(self, node: ir.Node) -> Replacement | None: self._modified = True # TODO(rama): consider merging type/other info from both values - # Do incremental shape inference - if self.shape_inference and not _is_control_flow_op(node): - self._do_inference(node) - - if node.domain not in self._opset_imports: + # Propagate const_value, and manually find out shape and type + # to avoid potentially expensive shape inference on large tensors. + if _is_onnx_op(node, "Constant"): + _process_constant_node(node) return None - version = self._opset_imports[node.domain] - op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) - for optimizer in op_optimizers: - assert optimizer - context = RewriterContext() - output = optimizer(node, context, self._state) - if output is not None: - if isinstance(output, Replacement): - return output - if isinstance(output, ir.Value): - output = [output] - return Replacement(output, context.nodes) if _is_control_flow_op(node): logger.info( @@ -1137,9 +1124,24 @@ def process_node(self, node: ir.Node) -> Replacement | None: ) return None - if _is_onnx_op(node, "Constant"): - _process_constant_node(node) + # Do incremental shape inference + if self.shape_inference: + self._do_inference(node) + + if node.domain not in self._opset_imports: return None + version = self._opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = RewriterContext() + output = optimizer(node, context, self._state) + if output is not None: + if isinstance(output, Replacement): + return output + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) if any(x.is_graph_input() for x in node.inputs if x is not None): logger.info( From 84deb77671dfe761fc992f82a63e23b66ce9fbc7 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 9 Oct 2025 21:45:01 +0000 Subject: [PATCH 2/2] revert some changes --- onnxscript/optimizer/_constant_folding.py | 40 ++++++++++++----------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index d2606f0b1e..9a740c783c 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1103,6 +1103,27 @@ def process_node(self, node: ir.Node) -> Replacement | None: # to avoid potentially expensive shape inference on large tensors. if _is_onnx_op(node, "Constant"): _process_constant_node(node) + # Do incremental shape inference + elif self.shape_inference and not _is_control_flow_op(node): + self._do_inference(node) + + if node.domain not in self._opset_imports: + return None + version = self._opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = RewriterContext() + output = optimizer(node, context, self._state) + if output is not None: + if isinstance(output, Replacement): + return output + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + + if _is_onnx_op(node, "Constant"): + logger.debug("Skipping constant folding for Constant node %r", node.name) return None if _is_control_flow_op(node): @@ -1124,25 +1145,6 @@ def process_node(self, node: ir.Node) -> Replacement | None: ) return None - # Do incremental shape inference - if self.shape_inference: - self._do_inference(node) - - if node.domain not in self._opset_imports: - return None - version = self._opset_imports[node.domain] - op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) - for optimizer in op_optimizers: - assert optimizer - context = RewriterContext() - output = optimizer(node, context, self._state) - if output is not None: - if isinstance(output, Replacement): - return output - if isinstance(output, ir.Value): - output = [output] - return Replacement(output, context.nodes) - if any(x.is_graph_input() for x in node.inputs if x is not None): logger.info( "Skipping constant folding for node %r because it is graph input to preserve graph signature",