Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions graph_net/torch/backend/unstable_to_stable_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import inspect
from .graph_compiler_backend import GraphCompilerBackend
from ..fx_graph_serialize_util import serialize_graph_module_to_str


class UnstableToStableBackend(GraphCompilerBackend):
Expand All @@ -29,8 +30,32 @@ def my_backend(gm, sample_inputs):
**Stable API reference link:**
"""

def avg_pool2d_to_avg_pool2d(self, gm):
"""
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
"""
import torch.nn.functional as F

# Update graph nodes: replace torch._C._nn.avg_pool2d with F.avg_pool2d
for node in gm.graph.nodes:
if node.op == "call_function":
if (
hasattr(node.target, "__module__")
and hasattr(node.target, "__name__")
and node.target.__module__ == "torch._C._nn"
and node.target.__name__ == "avg_pool2d"
):
node.target = F.avg_pool2d
Comment on lines +40 to +48
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我比较讨厌嵌套太深,会习惯性地改为如下列表解析的形式:

issue_nodes = (
    node
    for node in gm.graph.nodes
    if node.op == "call_function"
    if hasattr(node.target, "__module__")
    if node.target.__module__ == "torch._C._nn"
    if hasattr(node.target, "__name__")
    if node.target.__name__ == "avg_pool2d"
)
for node in issue_nodes:
    node.target = F.avg_pool2d

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好滴好滴,收到


# Recompile the graph
gm.recompile()

return gm

def unstable_to_stable(self, gm):
# TODO
# Convert based on unstable_api environment variable
if self.unstable_api == "torch._C._nn.avg_pool2d":
gm = self.avg_pool2d_to_avg_pool2d(gm)
return gm

def check_unstable_api(self, gm):
Expand All @@ -44,7 +69,8 @@ def check_unstable_api(self, gm):
Do NOT modify, remove, or bypass this check under any circumstances.
"""

graph_text = gm.code
# Use serialized code to check for unstable APIs
graph_text = serialize_graph_module_to_str(gm)
# Search for the unstable API substring
if self.unstable_api in graph_text:
count = graph_text.count(self.unstable_api)
Expand Down
5 changes: 3 additions & 2 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
from typing import Union, Callable
from . import utils
from .fx_graph_serialize_util import serialize_graph_module_to_str

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
Expand Down Expand Up @@ -89,9 +90,9 @@ def try_rename_placeholder(node):
assert input_idx == len(sample_inputs)
if self.mut_graph_codes is not None:
assert isinstance(self.mut_graph_codes, list)
self.mut_graph_codes.append(gm.code)
self.mut_graph_codes.append(serialize_graph_module_to_str(gm))
# 3. Generate and save model code
base_code = gm.code
base_code = serialize_graph_module_to_str(gm)
# gm.graph.print_tabular()
write_code = utils.apply_templates(base_code)
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
Expand Down
27 changes: 27 additions & 0 deletions graph_net/torch/fx_graph_serialize_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import re
import torch.fx


def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
"""
Serialize a GraphModule to a string representation, replacing unstable APIs
with their stable counterparts.

This function is used to normalize the code representation of GraphModule
for consistency checks and code generation.

Args:
gm: The GraphModule to serialize.

Returns:
A string representation of the GraphModule code with unstable APIs
replaced by stable ones.
"""
code = gm.code
# Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d
code = re.sub(
r"torch\._C\._nn\.avg_pool2d\(",
"torch.nn.functional.avg_pool2d(",
code,
)
return code