Skip to content

Commit 3c9c020

Browse files
committed
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
1 parent 612864d commit 3c9c020

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,26 @@ def _impl_unstable_to_stable_linalg_vector_norm(self, gm):
178178

179179
return gm
180180

181-
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
181+
def _impl_unstable_to_stable_linalg_norm(self, gm):
182+
"""
183+
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
184+
"""
185+
# Update graph nodes: replace torch._C._linalg.linalg_norm with torch.linalg.norm
186+
issue_nodes = (
187+
node
188+
for node in gm.graph.nodes
189+
if node.op == "call_function"
190+
if hasattr(node.target, "__module__")
191+
if node.target.__module__ == "torch._C._linalg"
192+
if hasattr(node.target, "__name__")
193+
if node.target.__name__ == "linalg_norm"
194+
)
195+
for node in issue_nodes:
196+
node.target = torch.linalg.norm
197+
198+
# Recompile the graph
199+
gm.recompile()
200+
return gm
182201

183202
def _impl_unstable_to_stable_softplus(self, gm):
184203
"""

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
140140
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
141141
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
142142
(r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("),
143-
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
143+
(r"torch\._C\._linalg\.linalg_norm\(", "torch.linalg.norm("),
144144
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145145
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
146146
(r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),

0 commit comments

Comments
 (0)