81 changes: 62 additions & 19 deletions graph_net/torch/test_compiler.py
@@ -13,15 +13,44 @@
 from contextlib import contextmanager
 import time
 
 
-def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
+try:
+    import torch_tensorrt
+except ImportError:
+    torch_tensorrt = None
+
+
+registry_backend = {
+    "inductor": {
+        "compiler": torch.compile,
+        "backend": "inductor",
+        "synchronizer": torch.cuda.synchronize,
+    },
+    "tensorrt": {
+        "compiler": torch.compile,
+        "backend": "tensorrt",
+        "synchronizer": torch.cuda.synchronize,
+    },
+    "default": {
+        "compiler": torch.compile,
+        "backend": "inductor",
+        "synchronizer": torch.cuda.synchronize,
+    },
+}
+
+
+def load_class_from_file(
+    args: argparse.Namespace, class_name: str
+) -> Type[torch.nn.Module]:
+    file_path = f"{args.model_path}/model.py"
     file = Path(file_path).resolve()
     module_name = file.stem
 
     with open(file_path, "r", encoding="utf-8") as f:
         original_code = f.read()
-    import_stmt = "import torch"
-    modified_code = f"{import_stmt}\n{original_code}"
+    if args.device == "cuda":
+        modified_code = original_code.replace("cpu", "cuda")
+    else:
+        modified_code = original_code
     spec = importlib.util.spec_from_loader(module_name, loader=None)
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
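For reference, a registry entry bundles everything the test path needs: the compile entry point, the backend string handed to it, and a synchronizer for accurate timing. A minimal sketch of how one entry is consumed, assuming the registry_backend dict above is in scope (the toy model is invented for illustration):

import torch

entry = registry_backend["inductor"]
model = torch.nn.Linear(8, 8)  # placeholder model
compiled = entry["compiler"](model, backend=entry["backend"])
out = compiled(torch.randn(2, 8))  # compilation is triggered on the first call
if torch.cuda.is_available():
    entry["synchronizer"]()  # torch.cuda.synchronize waits for pending kernels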
@@ -33,26 +62,32 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
 
 
 def get_compiler(args):
-    assert args.compiler == "default"
-    return torch.compile
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["compiler"]
 
 
+def get_backend(args):
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["backend"]
+
+
 def get_synchronizer_func(args):
-    assert args.compiler == "default"
-    return torch.cuda.synchronize
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["synchronizer"]
 
 
 def get_model(args):
-    model_class = load_class_from_file(
-        f"{args.model_path}/model.py", class_name="GraphModule"
-    )
-    return model_class()
+    model_class = load_class_from_file(args, class_name="GraphModule")
+    return model_class().to(torch.device(args.device))
 
 
 def get_input_dict(args):
     inputs_params = utils.load_converted_from_text(f"{args.model_path}")
     params = inputs_params["weight_info"]
-    return {k: utils.replay_tensor(v) for k, v in params.items()}
+    return {
+        k: utils.replay_tensor(v).to(torch.device(args.device))
+        for k, v in params.items()
+    }
 
 
 @dataclass
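The timing helpers referenced below are collapsed in this diff. For context, a duration-box timer of this shape is commonly written as follows; this is a sketch under that assumption, not the file's actual body:

import time
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class DurationBox:
    value: float


@contextmanager
def naive_timer(duration_box, get_synchronizer_func):
    get_synchronizer_func()  # drain pending async work before starting the clock
    start = time.time()
    yield
    get_synchronizer_func()  # wait for work launched inside the block to finish
    duration_box.value = time.time() - start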
@@ -72,10 +107,11 @@ def naive_timer(duration_box, get_synchronizer_func):
 
 def test_single_model(args):
     compiler = get_compiler(args)
+    backend = get_backend(args)
     synchronizer_func = get_synchronizer_func(args)
     input_dict = get_input_dict(args)
     model = get_model(args)
-    compiled_model = compiler(model)
+    compiled_model = compiler(model, backend=backend)
 
     # eager
     eager_duration_box = DurationBox(-1)
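The substantive change here is that the backend selected from the registry is now passed through to torch.compile. A self-contained example of the pattern, using a toy module (CPU-safe with the inductor backend; "tensorrt" additionally requires torch_tensorrt and a GPU):

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2.0


model = TinyModel()
compiled = torch.compile(model, backend="inductor")
y = compiled(torch.randn(4, 8))  # the first call compiles; later calls reuse the result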
@@ -157,11 +193,11 @@ def test_multi_models(args):
        cmd = "".join(
            [
                sys.executable,
-                "-m graph_net.torch.test_compiler",
-                f"--model-path {model_path}",
-                f"--compiler {args.compiler}",
-                f"--warmup {args.warmup}",
-                f"--log-prompt {args.log_prompt}",
+                " -m graph_net.torch.test_compiler",
+                f" --model-path {model_path}",
+                f" --compiler {args.compiler}",
+                f" --warmup {args.warmup}",
+                f" --log-prompt {args.log_prompt}",
            ]
        )
        cmd_ret = os.system(cmd)
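The leading spaces matter because "".join concatenates the fragments verbatim: the old list produced a fused string such as /usr/bin/python3-m graph_net.torch.test_compiler--model-path ..., which the shell cannot parse. Building an argument list and handing it to subprocess avoids manual whitespace management altogether; a sketch, with a hypothetical model path:

import subprocess
import sys

# Each argument is its own list element, so no separator bookkeeping is needed.
cmd = [
    sys.executable, "-m", "graph_net.torch.test_compiler",
    "--model-path", "/path/to/model",  # hypothetical
    "--compiler", "default",
    "--warmup", "5",
]
ret = subprocess.run(cmd, check=False).returncode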
@@ -212,6 +248,13 @@ def main(args):
         default="default",
         help="Path to customized compiler python file",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cpu",
+        help="Device for testing the compiler",
+    )
     parser.add_argument(
         "--warmup", type=int, required=False, default=5, help="Number of warmup steps"
     )
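Usage note: with this flag a GPU run looks like python -m graph_net.torch.test_compiler --model-path <model dir> --compiler tensorrt --device cuda, while the default remains cpu. The device is applied in three places above: load_class_from_file rewrites "cpu" to "cuda" in the generated model source, get_model moves the module with .to(torch.device(args.device)), and get_input_dict moves every replayed tensor the same way.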