Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
unsafe_remove_auto_functionalized_pass,
)
from torch.export.exported_program import (
_get_updated_range_constraints,
ConstantArgument,
ExportGraphSignature,
InputKind,
Expand All @@ -64,6 +63,39 @@
Val = Any


def _get_updated_range_constraints(gm):
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode # type: ignore[21]

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compated to var_to_range.
# e.g. [2, oo) -> [0, oo)
for k, v in shape_env.var_to_range.items():
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints


def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
Expand Down