diff --git a/exir/program/_program.py b/exir/program/_program.py index 990804fdcda..10d0043398f 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -48,7 +48,6 @@ unsafe_remove_auto_functionalized_pass, ) from torch.export.exported_program import ( - _get_updated_range_constraints, ConstantArgument, ExportGraphSignature, InputKind, @@ -64,6 +63,39 @@ Val = Any +def _get_updated_range_constraints(gm): + def get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode # type: ignore[21] + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + shape_env = get_shape_env(gm) + if shape_env is None: + return {} + range_constraints = { + k: v + for k, v in shape_env.var_to_range.items() + if k not in shape_env.replacements + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. [2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements: + range_constraints[k] = v + return range_constraints + + def _get_updated_graph_signature( old_signature: ExportGraphSignature, new_gm: torch.fx.GraphModule,