-
Notifications
You must be signed in to change notification settings - Fork 82
/
train_decomp.py
64 lines (50 loc) · 2.28 KB
/
train_decomp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_decomp import *
from networks import *
from options import DecompOptions
from data import get_loader
def getDecompNetworks(args):
    """Construct the four sub-networks used by the decomposition trainer.

    All hyper-parameters are read from ``args``: pose size, latent sizes
    (``dim_z_init``, ``dim_z_movement``), sequence length (``stdp_length``),
    and the hidden size / layer count of the movement encoder and the
    standard-pose decoder.

    Args:
        args: parsed options object (as produced by ``DecompOptions().parse()``
            in this script's entry point).

    Returns:
        tuple: ``(initp_enc, initp_dec, movement_enc, stdp_dec)`` — instances of
        ``InitPose_Enc``, ``InitPose_Dec``, ``Movement_Enc`` and
        ``StandardPose_Dec`` respectively.
    """
    pose_size = args.pose_size
    z_init = args.dim_z_init

    # NOTE(review): construction order is kept identical to the returned tuple
    # order on purpose — if these constructors draw random initial weights,
    # reordering them would change the initialization under a fixed seed.
    init_encoder = InitPose_Enc(pose_size=pose_size, dim_z_init=z_init)
    init_decoder = InitPose_Dec(pose_size=pose_size, dim_z_init=z_init)

    z_movement = args.dim_z_movement
    seq_len = args.stdp_length
    # The option is an integer flag; only the value 1 enables bidirection.
    use_bidirection = args.movement_enc_bidirection == 1
    movement_encoder = Movement_Enc(
        pose_size=pose_size,
        dim_z_movement=z_movement,
        length=seq_len,
        hidden_size=args.movement_enc_hidden_size,
        num_layers=args.movement_enc_num_layers,
        bidirection=use_bidirection,
    )

    standard_pose_decoder = StandardPose_Dec(
        pose_size=pose_size,
        dim_z_movement=z_movement,
        dim_z_init=z_init,
        length=seq_len,
        hidden_size=args.stdp_dec_hidden_size,
        num_layers=args.stdp_dec_num_layers,
    )

    return init_encoder, init_decoder, movement_encoder, standard_pose_decoder
def make_run_dir(base_dir, name):
    """Return ``os.path.join(base_dir, name)``, creating that directory if needed.

    Uses ``os.makedirs(..., exist_ok=True)``, which creates any missing parent
    directories and succeeds when the directory already exists. The previous
    ``os.path.exists`` + ``os.mkdir`` pair raised ``FileNotFoundError`` when
    ``base_dir`` itself did not exist yet, and could race with another process
    creating the same directory between the check and the ``mkdir``.

    Args:
        base_dir: root directory (e.g. the configured log or snapshot root).
        name: run name used as the sub-directory.

    Returns:
        str: the joined run directory path.
    """
    run_dir = os.path.join(base_dir, name)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


if __name__ == "__main__":
    parser = DecompOptions()
    args = parser.parse()
    args.train = True

    # Default run name; it becomes the sub-directory under the log/snapshot roots.
    if args.name is None:
        args.name = 'Decomp'
    # Rewrite both roots to their per-run sub-directories, creating them
    # (and any missing parents) on disk.
    args.log_dir = make_run_dir(args.log_dir, args.name)
    args.snapshot_dir = make_run_dir(args.snapshot_dir, args.name)

    data_loader = get_loader(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        dataset=args.dataset,
        data_dir=args.data_dir,
        tolerance=args.tolerance,
    )
    initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args)
    trainer = Trainer_Decomp(
        data_loader,
        initp_enc=initp_enc,
        initp_dec=initp_dec,
        movement_enc=movement_enc,
        stdp_dec=stdp_dec,
        args=args,
    )

    # When args.resume is set, trainer.resume supplies the (ep, it) to start
    # from; otherwise training starts at epoch 0, iteration 0.
    if args.resume is not None:
        ep, it = trainer.resume(args.resume, False)
    else:
        ep, it = 0, 0
    trainer.train(ep=ep, it=it)