Skip to content

Commit 612864d

Browse files
G2ugelixinqi
andauthored
【Hackathon 9th No.119】Convert torch._C._nn.one_hot to torch.nn.functional.one_hot (#335)
* Convert torch._C._nn.one_hot to torch.nn.functional.one_hot * resolve merge conflicts * resolve merge conflicts * update code --------- Co-authored-by: Li Xinqi <[email protected]>
1 parent a828505 commit 612864d

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,28 @@ def _impl_unstable_to_stable_softplus(self, gm):
201201
gm.recompile()
202202
return gm
203203

204-
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
204+
def _impl_unstable_to_stable_one_hot(self, gm):
205+
"""
206+
Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
207+
"""
208+
import torch.nn.functional as F
209+
210+
issue_nodes = (
211+
node
212+
for node in gm.graph.nodes
213+
if node.op == "call_function"
214+
if hasattr(node.target, "__module__")
215+
if node.target.__module__ == "torch._C._nn"
216+
if hasattr(node.target, "__name__")
217+
if node.target.__name__ == "one_hot"
218+
)
219+
for node in issue_nodes:
220+
node.target = F.one_hot
221+
222+
# Recompile the graph
223+
gm.recompile()
224+
225+
return gm
205226

206227
def _impl_unstable_to_stable_set_grad_enabled(self, gm):
207228
"""

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
142142
(r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("),
143143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
144144
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145-
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
145+
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
146146
(r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),
147147
(r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
148148
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)

0 commit comments

Comments
 (0)