Skip to content

Commit 48f3147

Browse files
committed
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
1 parent c65f7fa commit 48f3147

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,26 @@ def _impl_unstable_to_stable_special_logit(self, gm):
149149

150150
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
151151

152-
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
152+
def _impl_unstable_to_stable_linalg_norm(self, gm):
153+
"""
154+
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
155+
"""
156+
# Update graph nodes: replace torch._C._linalg.linalg_norm with torch.linalg.norm
157+
issue_nodes = (
158+
node
159+
for node in gm.graph.nodes
160+
if node.op == "call_function"
161+
if hasattr(node.target, "__module__")
162+
if node.target.__module__ == "torch._C._linalg"
163+
if hasattr(node.target, "__name__")
164+
if node.target.__name__ == "linalg_norm"
165+
)
166+
for node in issue_nodes:
167+
node.target = torch.linalg.norm
168+
169+
# Recompile the graph
170+
gm.recompile()
171+
return gm
153172

154173
# replace this line with modification code for task 118 (torch._C._nn.softplus)
155174

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
2727
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
2828
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
29-
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
29+
(r"torch\._C\._linalg\.linalg_norm\(", "torch.linalg.norm("),
3030
# replace this line with modification code for task 118 (torch._C._nn.softplus)
3131
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
3232
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)

0 commit comments

Comments
 (0)