@@ -30,7 +30,7 @@ def my_backend(gm, sample_inputs):
3030 **Stable API reference link:**
3131 """
3232
33- def fft_irfft_to_irfft (self , gm ):
33+ def _impl_unstable_to_stable_irfft (self , gm ):
3434 def replace_in_graph (graph_mod ):
3535 # Register stable implementation on GraphModule, codegen can use self.irfft
3636 try :
@@ -60,7 +60,7 @@ def replace_in_graph(graph_mod):
6060
6161 return gm
6262
63- def avg_pool2d_to_avg_pool2d (self , gm ):
63+ def _impl_unstable_to_stable_avg_pool2d (self , gm ):
6464 """
6565 Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
6666 """
@@ -82,12 +82,36 @@ def avg_pool2d_to_avg_pool2d(self, gm):
8282
8383 return gm
8484
85+ def _impl_unstable_to_stable_rfft (self , gm ):
86+ """
87+ Convert torch._C._fft.fft_rfft to torch.fft.rfft
88+ """
89+ # Update graph nodes: replace torch._C._fft.fft_rfft with torch.fft.rfft
90+ issue_nodes = (
91+ node
92+ for node in gm .graph .nodes
93+ if node .op == "call_function"
94+ if hasattr (node .target , "__module__" )
95+ if node .target .__module__ == "torch._C._fft"
96+ if hasattr (node .target , "__name__" )
97+ if node .target .__name__ == "fft_rfft"
98+ )
99+ for node in issue_nodes :
100+ node .target = torch .fft .rfft
101+
102+ # Recompile the graph
103+ gm .recompile ()
104+
105+ return gm
106+
85107 def unstable_to_stable (self , gm ):
86- # Convert based on unstable_api environment variable
87- if self .unstable_api == "torch._C._nn.avg_pool2d" :
88- gm = self .avg_pool2d_to_avg_pool2d (gm )
89- elif self .unstable_api == "torch._C._fft.fft_irfft" :
90- gm = self .fft_irfft_to_irfft (gm )
108+ methods = (
109+ name
110+ for name in vars (type (self )).keys ()
111+ if name .startswith ("_impl_unstable_to_stable" )
112+ )
113+ for method in methods :
114+ gm = getattr (self , method )(gm )
91115 return gm
92116
93117 def check_unstable_api (self , gm ):
0 commit comments