forked from pytorch/benchmark
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinstall.py
98 lines (90 loc) · 3.86 KB
/
install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import argparse
import subprocess
import os
import sys
import importlib
import tarfile
from torchbenchmark import setup, _test_https, proxy_suggestion, TORCH_DEPS
from torchbenchmark.util.env_check import get_pkg_versions
def git_lfs_checkout():
    """Fetch and check out the Git LFS objects of the benchmark checkout.

    Runs ``git lfs install``, ``git lfs fetch`` and ``git lfs checkout .`` in
    the directory that contains this script, stopping at the first failure.

    Returns:
        ``(True, None)`` on success.  On failure ``(False, err)``, where
        ``err`` is a ``str`` describing the failing git command together with
        its combined stdout/stderr, or the raised exception object itself if
        git could not be run at all (for example ``git`` is not on ``PATH``).
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    commands = (
        ['git', 'lfs', 'install'],
        ['git', 'lfs', 'fetch'],
        ['git', 'lfs', 'checkout', '.'],
    )
    try:
        for cmd in commands:
            # subprocess.run drains the pipe (check_call does not, and can
            # deadlock with stdout=PIPE) and attaches the captured output to
            # CalledProcessError so it can be shown to the user.
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT, cwd=tb_dir)
    except subprocess.CalledProcessError as e:
        output = (e.output or b"").decode("utf-8", errors="replace").strip()
        return False, (f"{e}\n{output}" if output else str(e))
    except Exception as e:
        return False, e
    return True, None
def decompress_input():
    """Extract every ``*.tar.gz`` archive found in ``torchbenchmark/data``.

    Archives are extracted into ``torchbenchmark/data/.data`` (created if it
    does not exist) so the decompressed files stay hidden and are not checked
    in.  A progress line is printed for each archive; any extraction error
    propagates to the caller.
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(tb_dir, "torchbenchmark", "data")
    # Hide decompressed files in the .data directory so they won't be checked in.
    decompress_dir = os.path.join(data_dir, ".data")
    os.makedirs(decompress_dir, exist_ok=True)
    # Sorted so extraction order (and thus which archive wins on overlapping
    # paths, plus the printed log) is deterministic across platforms.
    tarballs = sorted(name for name in os.listdir(data_dir) if name.endswith(".tar.gz"))
    for tarball in tarballs:
        print(f"decompressing input tarball: {tarball}...", end="", flush=True)
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() without a filter trusts member paths; these
        # archives ship with the repository, but consider filter="data"
        # (Python 3.12+) if they could ever come from an untrusted source.
        with tarfile.open(os.path.join(data_dir, tarball)) as tar:
            tar.extractall(path=decompress_dir)
        print("OK")
def pip_install_requirements():
    """Install the benchmark's ``requirements.txt`` with the current interpreter's pip.

    Exits the process with status -1 (after printing ``proxy_suggestion``) if
    ``_test_https()`` reports that HTTPS access is unavailable.

    Returns:
        ``(True, None)`` on success.  On failure ``(False, err)``, where
        ``err`` is pip's combined stdout/stderr decoded to ``str`` (or the
        error description if pip printed nothing), or the raised exception
        object if pip could not be launched.
    """
    if not _test_https():
        print(proxy_suggestion)
        sys.exit(-1)
    # Run from the script's directory so 'requirements.txt' (and any relative
    # paths inside it) resolve to the repository regardless of the caller's CWD.
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    try:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'],
                       check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                       cwd=tb_dir)
    except subprocess.CalledProcessError as e:
        output = (e.output or b"").decode("utf-8", errors="replace").strip()
        return False, (output or str(e))
    except Exception as e:
        return False, e
    return True, None
def main():
    """Command-line entry point: verify torch deps, fetch data and install the benchmark.

    Steps, in order:
      1. Check that the ``TORCH_DEPS`` packages are importable (exit -1 if not).
      2. Check out Git LFS files (exit -1 on failure).
      3. Decompress the input tarballs.
      4. pip-install the requirements (exit -1 on failure unless ``--continue_on_fail``).
      5. Verify the torch package versions did not change (exit -1 if they did).
      6. Run model ``setup``; raise ``RuntimeError`` on failure unless
         ``--continue_on_fail``, in which case only a warning is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--continue_on_fail", action="store_true")
    parser.add_argument("--models", nargs='+', default=[],
                        help="Specify one or more models to install. If not set, install all models.")
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    print(f"checking packages {', '.join(TORCH_DEPS)} are installed...", end="", flush=True)
    try:
        versions = get_pkg_versions(TORCH_DEPS)
    except ModuleNotFoundError:
        print("FAIL")
        print(f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark.")
        sys.exit(-1)
    print("OK")

    print("checking out Git LFS files...", end="", flush=True)
    success, errmsg = git_lfs_checkout()
    if not success:
        print("FAIL")
        print("Failed to checkout git lfs files. Please make sure you have installed git lfs.")
        print(errmsg)
        sys.exit(-1)
    print("OK")

    decompress_input()

    success, errmsg = pip_install_requirements()
    if not success:
        print("Failed to install torchbenchmark requirements:")
        print(errmsg)
        if not args.continue_on_fail:
            sys.exit(-1)

    # Installing the benchmark's requirements must not replace the user's torch build.
    new_versions = get_pkg_versions(TORCH_DEPS)
    if versions != new_versions:
        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the next line's indentation into the message.
        print("The torch packages are re-installed after installing the benchmark deps. "
              f"Before: {versions}, after: {new_versions}")
        sys.exit(-1)

    # Model setup still runs after a tolerated requirements failure, but the
    # earlier failure is remembered in `success`.
    success &= setup(models=args.models, verbose=args.verbose, continue_on_fail=args.continue_on_fail)
    if not success:
        if args.continue_on_fail:
            print("Warning: some benchmarks were not installed due to failure")
        else:
            raise RuntimeError("Failed to complete setup")


if __name__ == '__main__':
    main()