Commit ea928bb

Convert torch._C._nn.pad to torch.nn.functional.pad (#342)

1 parent 68ec87c commit ea928bb

2 files changed: +24 -2 lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 23 additions & 1 deletion
@@ -268,7 +268,29 @@ def replace_in_graph(graph_mod):
 
     # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
 
-    # replace this line with modification code for task 123 (torch._C._nn.pad)
+    def _impl_unstable_to_stable_pad(self, gm):
+        """
+        Convert torch._C._nn.pad to torch.nn.functional.pad
+        """
+        import torch.nn.functional as F
+
+        def replace_in_graph(graph_mod):
+            for node in graph_mod.graph.nodes:
+                if node.op == "call_function":
+                    if "pad" in str(node.target) and "torch._C._nn" in str(node.target):
+                        node.target = F.pad
+            graph_mod.recompile()
+
+        modules = [gm]
+        modules += [
+            m
+            for _, m in gm.named_modules()
+            if isinstance(m, torch.fx.GraphModule) and m is not gm
+        ]
+        for m in modules:
+            replace_in_graph(m)
+
+        return gm
 
     # replace this line with modification code for task 125 (torch._C._nn.gelu)
 
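For context, the following is a minimal, self-contained sketch (not part of the commit) of what this pass does: it builds a one-op torch.fx graph whose node targets the private binding, retargets that node to the public torch.nn.functional.pad, and checks that the recompiled module still computes the same result. The identity comparison on node.target is a simplification; the commit instead matches on str(node.target).

import torch
import torch.fx
import torch.nn.functional as F

# Build a one-op FX graph that calls the private binding directly.
graph = torch.fx.Graph()
x = graph.placeholder("x")
padded = graph.call_function(torch._C._nn.pad, (x, (1, 1), "constant", 0.0))
graph.output(padded)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

# Retarget the call_function node to the public API, as the pass above does.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch._C._nn.pad:
        node.target = F.pad
gm.recompile()

t = torch.randn(2, 3)
# Padding the last dimension by (1, 1) turns shape (2, 3) into (2, 5).
assert torch.equal(gm(t), F.pad(t, (1, 1), "constant", 0.0))

Note that on PyTorch builds where torch.nn.functional.pad and torch._C._nn.pad are the same underlying object, the retarget is a no-op at the object level, which is one reason the serializer-side text replacement in the second file below is still useful.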

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         (r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),
         (r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
-        # replace this line with modification code for task 123 (torch._C._nn.pad)
+        (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
         # replace this line with modification code for task 125 (torch._C._nn.gelu)
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
         (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
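To see the serializer-side fix in isolation, here is a small sketch. It assumes the table is a list of (pattern, replacement) pairs applied with re.sub to the generated source text, which is how the surrounding entries read; the two-entry table below is a hypothetical excerpt, not the full list from serialize_graph_module_to_str.

import re

# Hypothetical excerpt of the replacement table.
replacements = [
    (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
    (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]

code = "pad_1 = torch._C._nn.pad(x, (1, 1), 'constant', 0.0)"
for pattern, repl in replacements:
    code = re.sub(pattern, repl, code)

print(code)  # pad_1 = torch.nn.functional.pad(x, (1, 1), 'constant', 0.0)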
