Skip to content

Commit d5ce907

Browse files
authored
【Hackathon 9th No.114】torch._C._fft.fft_rfft API转换 torch.fft.rfft (#331)
* replace torch._C._fft.fft_rfft with torch.fft.rfft * updata code * 解决冲突问题
1 parent b45e611 commit d5ce907

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def my_backend(gm, sample_inputs):
3030
**Stable API reference link:**
3131
"""
3232

33-
def fft_irfft_to_irfft(self, gm):
33+
def _impl_unstable_to_stable_irfft(self, gm):
3434
def replace_in_graph(graph_mod):
3535
# Register stable implementation on GraphModule, codegen can use self.irfft
3636
try:
@@ -60,7 +60,7 @@ def replace_in_graph(graph_mod):
6060

6161
return gm
6262

63-
def avg_pool2d_to_avg_pool2d(self, gm):
63+
def _impl_unstable_to_stable_avg_pool2d(self, gm):
6464
"""
6565
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
6666
"""
@@ -82,12 +82,36 @@ def avg_pool2d_to_avg_pool2d(self, gm):
8282

8383
return gm
8484

85+
def _impl_unstable_to_stable_rfft(self, gm):
86+
"""
87+
Convert torch._C._fft.fft_rfft to torch.fft.rfft
88+
"""
89+
# Update graph nodes: replace torch._C._fft.fft_rfft with torch.fft.rfft
90+
issue_nodes = (
91+
node
92+
for node in gm.graph.nodes
93+
if node.op == "call_function"
94+
if hasattr(node.target, "__module__")
95+
if node.target.__module__ == "torch._C._fft"
96+
if hasattr(node.target, "__name__")
97+
if node.target.__name__ == "fft_rfft"
98+
)
99+
for node in issue_nodes:
100+
node.target = torch.fft.rfft
101+
102+
# Recompile the graph
103+
gm.recompile()
104+
105+
return gm
106+
85107
def unstable_to_stable(self, gm):
86-
# Convert based on unstable_api environment variable
87-
if self.unstable_api == "torch._C._nn.avg_pool2d":
88-
gm = self.avg_pool2d_to_avg_pool2d(gm)
89-
elif self.unstable_api == "torch._C._fft.fft_irfft":
90-
gm = self.fft_irfft_to_irfft(gm)
108+
methods = (
109+
name
110+
for name in vars(type(self)).keys()
111+
if name.startswith("_impl_unstable_to_stable")
112+
)
113+
for method in methods:
114+
gm = getattr(self, method)(gm)
91115
return gm
92116

93117
def check_unstable_api(self, gm):

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
1919
"""
2020
code = gm.code
2121
# Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d
22-
code = re.sub(
23-
r"torch\._C\._nn\.avg_pool2d\(",
24-
"torch.nn.functional.avg_pool2d(",
25-
code,
26-
)
27-
code = re.sub(
28-
r"torch\._C\._fft\.fft_irfft\(",
29-
"torch.fft.irfft(",
30-
code,
31-
)
22+
replacements = [
23+
(r"torch\._C\._nn\.avg_pool2d\(", "torch.nn.functional.avg_pool2d("),
24+
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
25+
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
26+
# Add new rules to this list as needed
27+
]
28+
for pattern, repl in replacements:
29+
code = re.sub(pattern, repl, code)
3230
return code

0 commit comments

Comments
 (0)