Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions graph_net/torch/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from contextlib import contextmanager
import time
import torch_tensorrt


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
Expand All @@ -20,8 +21,10 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul

with open(file_path, "r", encoding="utf-8") as f:
original_code = f.read()
import_stmt = "import torch"
modified_code = f"{import_stmt}\n{original_code}"
if torch.cuda.is_available():
modified_code = original_code.replace("cpu", "cuda")
else:
modified_code = original_code
spec = importlib.util.spec_from_loader(module_name, loader=None)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
Expand All @@ -33,20 +36,32 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul


def get_compiler(args):
    """Return the callable used to compile a model for ``args.compiler``.

    Both supported compiler names ("default" and "tensorrt") compile through
    ``torch.compile``; the backend string passed to it is selected separately
    by ``get_backend``.

    Args:
        args: Object exposing a ``compiler`` attribute.

    Returns:
        The ``torch.compile`` function.

    Raises:
        AssertionError: If ``args.compiler`` is not a supported name.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type stays the same for callers.
    if args.compiler not in ("default", "tensorrt"):
        raise AssertionError(f"unsupported compiler: {args.compiler!r}")
    return torch.compile


def get_backend(args):
    """Return the ``torch.compile`` backend name for ``args.compiler``.

    Args:
        args: Object exposing a ``compiler`` attribute.

    Returns:
        ``"tensorrt"`` when ``args.compiler`` is "tensorrt", and
        ``"inductor"`` when it is "default".

    Raises:
        AssertionError: If ``args.compiler`` is neither of the two names.
    """
    if args.compiler == "tensorrt":
        return "tensorrt"
    if args.compiler == "default":
        return "inductor"
    # An explicit raise (rather than ``assert``) keeps the check active under
    # ``python -O``, where an unknown name would otherwise silently fall
    # through to "inductor".
    raise AssertionError(f"unsupported compiler: {args.compiler!r}")


def get_synchronizer_func(args):
    """Return a callable that waits for pending device work to finish.

    Args:
        args: Object exposing a ``compiler`` attribute.

    Returns:
        ``torch.cuda.synchronize`` when CUDA is available; otherwise a no-op
        callable that accepts any arguments and returns ``None``.

    Raises:
        AssertionError: If ``args.compiler`` is neither "default" nor
            "tensorrt".
    """
    # Explicit raise so the validation is not stripped under ``python -O``.
    if args.compiler not in ("default", "tensorrt"):
        raise AssertionError(f"unsupported compiler: {args.compiler!r}")

    if torch.cuda.is_available():
        return torch.cuda.synchronize

    # ``get_model`` falls back to the CPU device when CUDA is unavailable;
    # calling ``torch.cuda.synchronize`` there would raise, so hand back a
    # callable that does nothing instead.
    def _no_op_synchronize(*_args, **_kwargs):
        return None

    return _no_op_synchronize


def get_model(args):
    """Load and instantiate the model stored under ``args.model_path``.

    The class named ``GraphModule`` is loaded from
    ``<args.model_path>/model.py`` via ``load_class_from_file``, instantiated
    with no arguments, and moved to the CUDA device when one is available
    (CPU otherwise).

    Args:
        args: Object exposing a ``model_path`` attribute.

    Returns:
        The instantiated model after ``.to(device)``.
    """
    model_class = load_class_from_file(
        f"{args.model_path}/model.py", class_name="GraphModule"
    )
    # Return only after the device move; returning the bare instance first
    # would leave the model on its construction device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model_class().to(device)


def get_input_dict(args):
Expand All @@ -72,10 +87,11 @@ def naive_timer(duration_box, get_synchronizer_func):

def test_single_model(args):
compiler = get_compiler(args)
backend = get_backend(args)
synchronizer_func = get_synchronizer_func(args)
input_dict = get_input_dict(args)
model = get_model(args)
compiled_model = compiler(model)
compiled_model = compiler(model, backend=backend)

# eager
eager_duration_box = DurationBox(-1)
Expand Down Expand Up @@ -157,11 +173,11 @@ def test_multi_models(args):
cmd = "".join(
[
sys.executable,
"-m graph_net.torch.test_compiler",
f"--model-path {model_path}",
f"--compiler {args.compiler}",
f"--warmup {args.warmup}",
f"--log-prompt {args.log_prompt}",
" -m graph_net.torch.test_compiler",
f" --model-path {model_path}",
f" --compiler {args.compiler}",
f" --warmup {args.warmup}",
f" --log-prompt {args.log_prompt}",
]
)
cmd_ret = os.system(cmd)
Expand Down
2 changes: 1 addition & 1 deletion graph_net/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def extract_dynamic_shapes(example_inputs):


def replay_tensor(info):
device = info["info"]["device"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = info["info"]["dtype"]
shape = info["info"]["shape"]

Expand Down