# train.py
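"""Training launcher.

Reads the run configuration from main_utils (parser_args), prepares a
timestamped output directory, and launches one training process per GPU.
"""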
from main_utils import *  # expected to provide parser_args, print_time, setup_training_loop_kwargs
import tempfile
import os
import json
import shutil
import time
import torch
import dnnlib
from torch_utils import training_stats
from torch_utils import custom_ops
from training import training_loop
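
# Per-process entry point. Each worker re-opens the shared run log, joins the
# torch.distributed process group through a file:// rendezvous in temp_dir
# (gloo on Windows, where NCCL is unavailable; NCCL elsewhere), initializes
# cross-process training statistics, and then enters the training loop.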
def subprocess_fn(rank, args, temp_dir):
    dnnlib.util.Logger(file_name=os.path.join(args.run_dir, "log.txt"), file_mode="a", should_flush=True)

    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, ".torch_distributed_init"))
        if os.name == "nt":
            init_method = "file:///" + init_file.replace("\\", "/")
            torch.distributed.init_process_group(backend="gloo", init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f"file://{init_file}"
            torch.distributed.init_process_group(backend="nccl", init_method=init_method, rank=rank, world_size=args.num_gpus)

    sync_device = torch.device("cuda", rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = "none"

    training_loop.training_loop(rank, args)
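
# Top-level driver: resolves the training options, prints and saves them to
# <run_dir>/training_options.json, then runs subprocess_fn directly for a
# single GPU or via torch.multiprocessing.spawn for multiple GPUs.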
def main():
    #log = set_logger(parser_args)  # TODO: change where the logger is saved and include the time
    #log.info("parser_args: {}".format(parser_args))
    dnnlib.util.Logger(should_flush=True)  # TODO: how should this be configured?
    print("\n\nBeginning of process.")
    print_time()

    run_desc = setup_training_loop_kwargs(parser_args)  # run_desc is currently unused
    # NOTE: the ":" characters in the timestamp make this path invalid on Windows.
    parser_args.run_dir = parser_args.save_ckpt + "_" + time.strftime("%m-%d_%H:%M:%S", time.localtime())

    # Print the resolved training options.
    print()
    print("Training options:")
    print(json.dumps(vars(parser_args), indent=2))
    print()
    print(f"Output directory: {parser_args.run_dir}")
    print(f"Training data: {parser_args.training_set_kwargs.path}")
    print(f"Training duration: {parser_args.total_kimg} kimg")
    print(f"Number of GPUs: {parser_args.num_gpus}")
    print(f"Number of images: {parser_args.training_set_kwargs.max_size}")
    print(f"Image resolution: {parser_args.training_set_kwargs.resolution}")
    print(f"Conditional model: {parser_args.training_set_kwargs.use_labels}")
    print(f"Dataset x-flips: {parser_args.training_set_kwargs.xflip}")
    print()

    # Create the output directory and save the options alongside the run.
    print("Creating output directory...")
    os.makedirs(parser_args.run_dir)
    with open(os.path.join(parser_args.run_dir, "training_options.json"), "wt") as f:
        json.dump(vars(parser_args), f, indent=2)

    # Launch one process per GPU.
    print("Launching processes...")
    torch.multiprocessing.set_start_method("spawn")
    with tempfile.TemporaryDirectory() as temp_dir:
        if parser_args.num_gpus == 1:
            subprocess_fn(rank=0, args=parser_args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(parser_args, temp_dir), nprocs=parser_args.num_gpus)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # On Ctrl-C, remove the torch extension build cache so an interrupted
        # custom-op build does not leave stale artifacts behind.
        if os.path.exists("/root/.cache/torch_extensions"):
            shutil.rmtree("/root/.cache/torch_extensions")
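
# Usage sketch: all options come from parser_args defined in main_utils, so a
# run is launched simply as:
#
#   python train.py
#
# (Any command-line flags are parsed inside main_utils; none are defined here.)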