-
Notifications
You must be signed in to change notification settings - Fork 31
/
pretrain_a2kp.py
95 lines (72 loc) · 3.95 KB
/
pretrain_a2kp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
# from frames_dataset_transformer25 import FramesWavsDatasetMEL25VoxWoTBatch as FramesWavsDatasetMEL25
# from frames_dataset_transformer25 import FramesWavsDatasetMEL25VoxBoxQG2Img as FramesWavsDatasetMEL25
from frames_dataset_transformer25 import FramesWavsDatasetMEL25VoxBoxQG2 as FramesWavsDatasetMEL25
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector, HEEstimator
from modules.transformer import Audio2kpTransformerBBoxQDeep as Audio2kpTransformer
import torch
from train_transformer import train_batch as train
if __name__ == "__main__":
    # Entry point: parse CLI options, build the generator / discriminator /
    # keypoint detector / audio-to-keypoint transformer from the YAML config,
    # create a fresh timestamped log directory, and hand everything to the
    # imported `train` function (train_transformer.train_batch).
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-transformer.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train",])
    parser.add_argument("--gen", default="spade", choices=["original", "spade"])
    parser.add_argument("--log_dir", default='./output/', help="path to log into")
    parser.add_argument("--checkpoint", default='./00000189-checkpoint.pth.tar', help="path to checkpoint to restore")
    # Parsed into a list of ints; int() tolerates the spaces in the default string.
    parser.add_argument("--device_ids", default="0, 1, 2, 3, 4, 5, 6, 7", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)
    opt = parser.parse_args()

    # NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load. That is
    # acceptable for a local, trusted config file; prefer safe_load if configs
    # could ever come from an untrusted source.
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Every run gets a new directory "<log_dir>/<config name> <UTC timestamp>",
    # even when resuming from --checkpoint (reusing the checkpoint's own
    # directory was disabled in this script).
    log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
    log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

    model_params = config['model_params']
    common_params = model_params['common_params']

    # argparse `choices` guarantees opt.gen is one of these two values.
    if opt.gen == 'original':
        generator = OcclusionAwareGenerator(**model_params['generator_params'],
                                            **common_params)
    elif opt.gen == 'spade':
        generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'],
                                                 **common_params)

    # Query CUDA once instead of once per model.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print('cuda is available')

    def _place_model(model):
        """Move `model` to the first requested device when CUDA is available,
        and print its architecture when --verbose is set."""
        if use_cuda:
            model.to(opt.device_ids[0])
        if opt.verbose:
            print(model)
        return model

    _place_model(generator)

    discriminator = MultiScaleDiscriminator(**model_params['discriminator_params'],
                                            **common_params)
    _place_model(discriminator)

    kp_detector = KPDetector(**model_params['kp_detector_params'],
                             **common_params)
    _place_model(kp_detector)

    # The transformer takes only its own params, not common_params.
    audio2kptransformer = Audio2kpTransformer(**model_params['audio2kp_params'])
    _place_model(audio2kptransformer)

    dataset = FramesWavsDatasetMEL25(is_train=(opt.mode == 'train'), **config['dataset_params'])

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    # Keep a copy of the config next to the run's outputs for reproducibility.
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == 'train':
        print("Training...")
        train(config, generator, discriminator, kp_detector, audio2kptransformer, opt.checkpoint, log_dir, dataset, opt.device_ids)