File tree Expand file tree Collapse file tree 2 files changed +24
-2
lines changed Expand file tree Collapse file tree 2 files changed +24
-2
lines changed Original file line number Diff line number Diff line change @@ -268,7 +268,29 @@ def replace_in_graph(graph_mod):
268268
269269 # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
270270
271- # replace this line with modification code for task 123 (torch._C._nn.pad)
271+ def _impl_unstable_to_stable_pad (self , gm ):
272+ """
273+ Convert torch._C._nn.pad to torch.nn.functional.pad
274+ """
275+ import torch .nn .functional as F
276+
277+ def replace_in_graph (graph_mod ):
278+ for node in graph_mod .graph .nodes :
279+ if node .op == "call_function" :
280+ if "pad" in str (node .target ) and "torch._C._nn" in str (node .target ):
281+ node .target = F .pad
282+ graph_mod .recompile ()
283+
284+ modules = [gm ]
285+ modules += [
286+ m
287+ for _ , m in gm .named_modules ()
288+ if isinstance (m , torch .fx .GraphModule ) and m is not gm
289+ ]
290+ for m in modules :
291+ replace_in_graph (m )
292+
293+ return gm
272294
273295 # replace this line with modification code for task 125 (torch._C._nn.gelu)
274296
Original file line number Diff line number Diff line change @@ -146,7 +146,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
146146 (r"torch\._C\._set_grad_enabled\(" , "torch.set_grad_enabled(" ),
147147 (r"torch\._C\.set_grad_enabled\(" , "torch.set_grad_enabled(" ),
148148 # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
149- # replace this line with modification code for task 123 ( torch._C._nn.pad)
149+ ( r" torch\ ._C\ ._nn\ .pad\(" , "torch.nn.functional.pad(" ),
150150 # replace this line with modification code for task 125 (torch._C._nn.gelu)
151151 # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
152152 (r"torch\._C\._nn\.linear\(" , "torch.nn.functional.linear(" ),
You can’t perform that action at this time.
0 commit comments