Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 96 additions & 66 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import json
import shutil
from typing import Union, Callable
from . import utils

Expand Down Expand Up @@ -80,76 +81,105 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
def wrapper(model: torch.nn.Module):
assert isinstance(model, torch.nn.Module), f"{type(model)=}"

def extractor(gm: torch.fx.GraphModule, sample_inputs):
# 1. Get workspace path
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
if not workspace_path:
raise EnvironmentError(
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
class GraphExtractor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以定义到 globals 下边吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已在#235 中修改

def __init__(self):
    """Start with no graphs captured yet."""
    # Counts how many graphs this extractor instance has been asked to handle.
    self.subgraph_counter = 0

def move_files(self, source_dir, target_dir):
    """Relocate the regular files found directly in ``source_dir``.

    ``target_dir`` is created when it does not exist yet. Entries of
    ``source_dir`` that are not regular files (for example nested
    directories) stay where they are.
    """
    os.makedirs(target_dir, exist_ok=True)
    for entry_name in os.listdir(source_dir):
        origin = os.path.join(source_dir, entry_name)
        # Guard clause: only plain files are moved.
        if not os.path.isfile(origin):
            continue
        destination = os.path.join(target_dir, entry_name)
        shutil.move(origin, destination)

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
# 1. Get workspace path
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
if not workspace_path:
raise EnvironmentError(
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
)
model_path = os.path.join(workspace_path, name)
os.makedirs(model_path, exist_ok=True)

if self.subgraph_counter == 0:
subgraph_path = model_path
else:
if self.subgraph_counter == 1:
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
self.move_files(model_path, subgraph_0_path)

subgraph_path = os.path.join(
model_path, f"subgraph_{self.subgraph_counter}"
)
os.makedirs(subgraph_path, exist_ok=True)

self.subgraph_counter += 1

# 2. Get full params
params = {}
input_idx = 0
unique_id = 0

def try_rename_placeholder(node):
assert node.op == "placeholder"
if not placeholder_auto_rename:
return
nonlocal unique_id
node.target = f"v{unique_id}"
unique_id += 1
node.name = f"v{unique_id}"
unique_id += 1

for node in gm.graph.nodes:
if node.op == "placeholder":
try_rename_placeholder(node)
input = sample_inputs[input_idx]
if isinstance(input, torch.SymInt):
input = torch.tensor(4)
params[node.target] = input
input_idx += 1
assert input_idx == len(sample_inputs)
if mut_graph_codes is not None:
assert isinstance(mut_graph_codes, list)
mut_graph_codes.append(gm.code)
# 3. Generate and save model code
base_code = gm.code
# gm.graph.print_tabular()
write_code = utils.apply_templates(base_code)
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
fp.write(write_code)

# 4. Save metadata
metadata = {
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": bool(dynamic),
"model_name": name,
}
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
json.dump(metadata, f, indent=4)

# 5. Save tensor metadata
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(
subgraph_path, "input_tensor_constraints.py"
),
)
model_path = os.path.join(workspace_path, name)
os.makedirs(model_path, exist_ok=True)

# 2. Get full params
params = {}
input_idx = 0
unique_id = 0

def try_rename_placeholder(node):
assert node.op == "placeholder"
if not placeholder_auto_rename:
return
nonlocal unique_id
node.target = f"v{unique_id}"
unique_id += 1
node.name = f"v{unique_id}"
unique_id += 1

for node in gm.graph.nodes:
if node.op == "placeholder":
try_rename_placeholder(node)
input = sample_inputs[input_idx]
if isinstance(input, torch.SymInt):
input = torch.tensor(4)
params[node.target] = input
input_idx += 1
assert input_idx == len(sample_inputs)
if mut_graph_codes is not None:
assert isinstance(mut_graph_codes, list)
mut_graph_codes.append(gm.code)
# 3. Generate and save model code
base_code = gm.code
# gm.graph.print_tabular()
write_code = utils.apply_templates(base_code)
with open(os.path.join(model_path, "model.py"), "w") as fp:
fp.write(write_code)

# 4. Save metadata
metadata = {
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": bool(dynamic),
"model_name": name,
}
with open(os.path.join(model_path, "graph_net.json"), "w") as f:
json.dump(metadata, f, indent=4)

# 5. Save tensor metadata
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=model_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(model_path, "input_tensor_constraints.py"),
)

print(
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
)
print(
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
)

return gm.forward
return gm.forward

extractor = GraphExtractor()
# return torch.compile(backend=extractor, dynamic=dynamic)
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)

Expand Down
37 changes: 33 additions & 4 deletions graph_net/torch/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def temp_workspace():
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = old


def main(args):
model_path = args.model_path
def validate(args, model_path):
with temp_workspace() as tmp_dir_name:
print("Check extractability ...")
cmd = f"{sys.executable} -m graph_net.torch.single_device_runner --model-path {model_path}"
Expand All @@ -36,16 +35,46 @@ def main(args):
if args.graph_net_samples_path is None
else args.graph_net_samples_path
)
cmd = f"{sys.executable} -m graph_net.torch.check_redundant_incrementally --model-path {args.model_path} --graph-net-samples-path {graph_net_samples_path}"
cmd = f"{sys.executable} -m graph_net.torch.check_redundant_incrementally --model-path {model_path} --graph-net-samples-path {graph_net_samples_path}"
cmd_ret = os.system(cmd)
rm_cmd = f"{sys.executable} -m graph_net.torch.remove_redundant_incrementally --model-path {args.model_path} --graph-net-samples-path {graph_net_samples_path}"
rm_cmd = f"{sys.executable} -m graph_net.torch.remove_redundant_incrementally --model-path {model_path} --graph-net-samples-path {graph_net_samples_path}"
assert (
cmd_ret == 0
), f"\nPlease use the following command to remove redundant model directories:\n\n{rm_cmd}\n"

print(f"Validation success, {model_path=}")


def get_recursively_model_path(root_dir):
    """Yield every model directory found below ``root_dir``.

    Each immediate subdirectory is tested with ``is_single_model_dir``.
    A match is yielded as-is and not searched any further; any other
    subdirectory is searched recursively. ``root_dir`` itself is never
    yielded.
    """
    for candidate in get_immediate_subdirectory_paths(root_dir):
        if not is_single_model_dir(candidate):
            # Not a model directory: keep descending.
            yield from get_recursively_model_path(candidate)
            continue
        yield candidate


def get_immediate_subdirectory_paths(parent_dir):
    """Return the paths of the directories located directly in ``parent_dir``.

    Each returned path is ``parent_dir`` joined with the entry name, in the
    order reported by ``os.listdir``. Non-directory entries are dropped.
    """
    entry_paths = (
        os.path.join(parent_dir, entry_name)
        for entry_name in os.listdir(parent_dir)
    )
    return [path for path in entry_paths if os.path.isdir(path)]


def is_single_model_dir(model_dir):
    """Tell whether ``model_dir`` directly contains a ``graph_net.json`` file."""
    marker_path = f"{model_dir}/graph_net.json"
    return os.path.isfile(marker_path)


def main(args):
    """Validate ``args.model_path``, or every model directory below it.

    When ``args.model_path`` is itself a model directory it is validated
    directly; otherwise each model directory discovered underneath it is
    validated in turn.
    """
    root = args.model_path
    if is_single_model_dir(root):
        targets = [root]
    else:
        # Lazy generator: each directory is validated as soon as it is found.
        targets = get_recursively_model_path(root)
    for target in targets:
        validate(args, target)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Validate a computation graph sample. return 0 if success"
Expand Down