Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool:

def _process_constant_node(node: ir.Node) -> None:
"""Sets const_value of output value of a Constant op node."""
if node.op_type != "Constant" or node.domain != "":
if not _is_onnx_op(node, "Constant"):
return
if len(node.attributes) != 1:
return
Expand Down Expand Up @@ -1099,8 +1099,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
self._modified = True
# TODO(rama): consider merging type/other info from both values

# Propagate const_value, and manually find out shape and type
# to avoid potentially expensive shape inference on large tensors.
if _is_onnx_op(node, "Constant"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that we can standardize this with other node processers, but non blocking for now.

_process_constant_node(node)
# Do incremental shape inference
if self.shape_inference and not _is_control_flow_op(node):
elif self.shape_inference and not _is_control_flow_op(node):
self._do_inference(node)

if node.domain not in self._opset_imports:
Expand All @@ -1118,6 +1122,10 @@ def process_node(self, node: ir.Node) -> Replacement | None:
output = [output]
return Replacement(output, context.nodes)

if _is_onnx_op(node, "Constant"):
logger.debug("Skipping constant folding for Constant node %r", node.name)
return None

if _is_control_flow_op(node):
logger.info(
"Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
Expand All @@ -1137,10 +1145,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
)
return None

if _is_onnx_op(node, "Constant"):
_process_constant_node(node)
return None

if any(x.is_graph_input() for x in node.inputs if x is not None):
logger.info(
"Skipping constant folding for node %r because it is graph input to preserve graph signature",
Expand Down