-
Notifications
You must be signed in to change notification settings - Fork 40
/
train_facade.py
119 lines (103 loc) · 4.61 KB
/
train_facade.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python
# python train_facade.py -g 0 -i ./facade/base --out result_facade --snapshot_interval 10000
from __future__ import print_function
import argparse
import os
import chainer
from chainer import training
from chainer.training import extensions
from chainer import serializers
from net import Discriminator
from net import Encoder
from net import Decoder
from updater import FacadeUpdater
from facade_dataset import FacadeDataset
from facade_visualizer import out_image
def main():
    """Train pix2pix (Encoder/Decoder generator plus Discriminator) on facades.

    Reads command-line options, builds the three networks and one Adam
    optimizer per network, wraps the train/test splits of ``FacadeDataset``
    in serial iterators, and hands everything to ``FacadeUpdater``.

    The trainer runs for ``--epoch`` epochs. Every ``--snapshot_interval``
    iterations it writes a full snapshot, per-network weight snapshots and
    the ``out_image`` extension output. Every ``--display_interval``
    iterations it logs and prints the losses.

    When ``--resume`` is non-empty, the trainer state is loaded from that
    snapshot file before training starts.
    """
    # (flags, keyword arguments) for every command-line option, in the
    # order they are registered on the parser.
    option_specs = (
        (('--batchsize', '-b'),
         dict(type=int, default=1,
              help='Number of images in each mini-batch')),
        (('--epoch', '-e'),
         dict(type=int, default=200,
              help='Number of sweeps over the dataset to train')),
        (('--gpu', '-g'),
         dict(type=int, default=-1,
              help='GPU ID (negative value indicates CPU)')),
        (('--dataset', '-i'),
         dict(default='./facade/base', help='Directory of image files.')),
        (('--out', '-o'),
         dict(default='result', help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='', help='Resume the training from snapshot')),
        (('--seed',),
         dict(type=int, default=0, help='Random seed')),
        (('--snapshot_interval',),
         dict(type=int, default=1000, help='Interval of snapshot')),
        (('--display_interval',),
         dict(type=int, default=100,
              help='Interval of displaying log to console')),
    )
    cli = argparse.ArgumentParser(
        description='chainer implementation of pix2pix')
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    opts = cli.parse_args()

    for line in ('GPU: {}'.format(opts.gpu),
                 '# Minibatch-size: {}'.format(opts.batchsize),
                 '# epoch: {}'.format(opts.epoch),
                 ''):
        print(line)

    # Generator halves and discriminator, created in this fixed order.
    encoder = Encoder(in_ch=12)
    decoder = Decoder(out_ch=3)
    discriminator = Discriminator(in_ch=12, out_ch=3)
    networks = (encoder, decoder, discriminator)

    if opts.gpu >= 0:
        # Make the requested GPU current, then copy every network onto it.
        chainer.cuda.get_device(opts.gpu).use()
        for link in networks:
            link.to_gpu()

    def build_adam(link, alpha=0.0002, beta1=0.5):
        # Adam with a small weight-decay hook, bound to a single network.
        adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        adam.setup(link)
        adam.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return adam

    opt_enc, opt_dec, opt_dis = [build_adam(link) for link in networks]

    train_data = FacadeDataset(opts.dataset, data_range=(1, 300))
    test_data = FacadeDataset(opts.dataset, data_range=(300, 379))
    # NOTE: chainer.iterators.MultiprocessIterator(dataset, batchsize,
    # n_processes=4) can be swapped in here for parallel loading; the
    # serial iterator is what is currently used.
    train_iter = chainer.iterators.SerialIterator(train_data, opts.batchsize)
    test_iter = chainer.iterators.SerialIterator(test_data, opts.batchsize)

    updater = FacadeUpdater(
        models=networks,
        iterator={'main': train_iter, 'test': test_iter},
        optimizer={'enc': opt_enc, 'dec': opt_dec, 'dis': opt_dis},
        device=opts.gpu)
    trainer = training.Trainer(updater, (opts.epoch, 'epoch'), out=opts.out)

    snapshot_trigger = (opts.snapshot_interval, 'iteration')
    log_trigger = (opts.display_interval, 'iteration')

    # Whole-trainer snapshot first, then each network's weights on its own.
    # Registration order is kept stable so extension names stay the same.
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_trigger)
    per_model_snapshots = (
        (encoder, 'enc_iter_{.updater.iteration}.npz'),
        (decoder, 'dec_iter_{.updater.iteration}.npz'),
        (discriminator, 'dis_iter_{.updater.iteration}.npz'),
    )
    for link, filename in per_model_snapshots:
        trainer.extend(extensions.snapshot_object(link, filename),
                       trigger=snapshot_trigger)

    trainer.extend(extensions.LogReport(trigger=log_trigger))
    trainer.extend(
        extensions.PrintReport(
            ['epoch', 'iteration', 'enc/loss', 'dec/loss', 'dis/loss']),
        trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(
        out_image(updater, encoder, decoder, 5, 5, opts.seed, opts.out),
        trigger=snapshot_trigger)

    if opts.resume:
        # Restore the complete trainer state from an earlier snapshot file.
        chainer.serializers.load_npz(opts.resume, trainer)

    trainer.run()
if __name__ == '__main__':
    # Start training only when this file is executed as a script, so that
    # importing the module does not launch a training run.
    main()