Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions graph_net/torch/backend/unstable_to_stable_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,34 @@ def avg_pool2d_to_avg_pool2d(self, gm):

return gm

def fft_rfft_to_rfft(self, gm):
"""
Convert torch._C._fft.fft_rfft to torch.fft.rfft
"""
# Update graph nodes: replace torch._C._fft.fft_rfft with torch.fft.rfft
issue_nodes = (
node
for node in gm.graph.nodes
if node.op == "call_function"
if hasattr(node.target, "__module__")
if node.target.__module__ == "torch._C._fft"
if hasattr(node.target, "__name__")
if node.target.__name__ == "fft_rfft"
)
for node in issue_nodes:
node.target = torch.fft.rfft

# Recompile the graph
gm.recompile()

return gm

def unstable_to_stable(self, gm):
# Convert based on unstable_api environment variable
if self.unstable_api == "torch._C._nn.avg_pool2d":
gm = self.avg_pool2d_to_avg_pool2d(gm)
elif self.unstable_api == "torch._C._fft.fft_rfft":
gm = self.fft_rfft_to_rfft(gm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里每次提交都会冲突,给合入造成了一些小困扰。我建立在这里写一行这样的代码:

methods = (name for name in vars(type(self)).keys() if name.startswith('_impl_unstable_to_stable'))
for method in methods:
    gm = getattr(self, method)(gm)

这样每个人的代码就写到各自的_impl_unstable_to_stable_XXX方法里,不再彼此干扰。

return gm

def check_unstable_api(self, gm):
Expand Down
5 changes: 5 additions & 0 deletions graph_net/torch/fx_graph_serialize_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
"torch.nn.functional.avg_pool2d(",
code,
)
code = re.sub(
r"torch\._C\._fft\.fft_rfft\(",
"torch.fft.rfft(",
code,
)
return code