Skip to content

Commit a828505

Browse files
authored
Convert torch._C._linalg.linalg_vector_norm to torch.linalg.vector_norm (#339)
1 parent afa6121 commit a828505

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,27 @@ def _impl_unstable_to_stable_special_logit(self, gm):
156156
gm.recompile()
157157
return gm
158158

159-
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
159+
def _impl_unstable_to_stable_linalg_vector_norm(self, gm):
160+
"""
161+
Convert torch._C._linalg.linalg_vector_norm to torch.linalg.vector_norm
162+
"""
163+
# Update graph nodes: replace torch._C._linalg.linalg_vector_norm with torch.linalg.vector_norm
164+
issue_nodes = (
165+
node
166+
for node in gm.graph.nodes
167+
if node.op == "call_function"
168+
if hasattr(node.target, "__module__")
169+
if node.target.__module__ == "torch._C._linalg"
170+
if hasattr(node.target, "__name__")
171+
if node.target.__name__ == "linalg_vector_norm"
172+
)
173+
for node in issue_nodes:
174+
node.target = torch.linalg.vector_norm
175+
176+
# Recompile the graph
177+
gm.recompile()
178+
179+
return gm
160180

161181
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
162182

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
139139
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
140140
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
141141
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
142-
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
142+
(r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("),
143143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
144144
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145145
# replace this line with modification code for task 119 (torch._C._nn.one_hot)

0 commit comments

Comments
 (0)