Skip to content

Commit e9da94f

Browse files
committed
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
1 parent 44ca1a5 commit e9da94f

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129+
def _impl_unstable_to_stable_linalg_norm(self, gm):
130+
"""
131+
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
132+
"""
133+
# Update graph nodes: replace torch._C._linalg.linalg_norm with torch.linalg.norm
134+
issue_nodes = (
135+
node
136+
for node in gm.graph.nodes
137+
if node.op == "call_function"
138+
if hasattr(node.target, "__module__")
139+
if node.target.__module__ == "torch._C._linalg"
140+
if hasattr(node.target, "__name__")
141+
if node.target.__name__ == "linalg_norm"
142+
)
143+
for node in issue_nodes:
144+
node.target = torch.linalg.norm
145+
146+
# Recompile the graph
147+
gm.recompile()
148+
return gm
149+
129150
def unstable_to_stable(self, gm):
130151
methods = (
131152
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27+
(r"torch\._C\._linalg\.linalg_norm\(", "torch.linalg.norm("),
2728
# Add new rules to this list as needed
2829
]
2930
for pattern, repl in replacements:

0 commit comments

Comments
 (0)