diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index f8cb3969c..fafe2c4cf 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -147,7 +147,27 @@ def _impl_unstable_to_stable_special_logit(self, gm): return gm - # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm) + def _impl_unstable_to_stable_linalg_vector_norm(self, gm): + """ + Convert torch._C._linalg.linalg_vector_norm to torch.linalg.vector_norm + """ + # Update graph nodes: replace torch._C._linalg.linalg_vector_norm with torch.linalg.vector_norm + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._linalg" + if hasattr(node.target, "__name__") + if node.target.__name__ == "linalg_vector_norm" + ) + for node in issue_nodes: + node.target = torch.linalg.vector_norm + + # Recompile the graph + gm.recompile() + + return gm # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm) diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index a67f83882..b6ebd77c2 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -25,7 +25,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: (r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("), (r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("), (r"torch\._C\._special\.special_logit\(", "torch.special.logit("), - # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm) + (r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("), # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm) # replace this line with modification code for task 118 (torch._C._nn.softplus) # replace this line with modification code for task 119 (torch._C._nn.one_hot)