-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
429 lines (370 loc) · 13.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
import argparse
import itertools
import os
import random
import torch
from torch.backends import cudnn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from cyclegan_pytorch import DecayLR
from cyclegan_pytorch import Discriminator
from cyclegan_pytorch import Generator
from cyclegan_pytorch import ImageDataset
from cyclegan_pytorch import ReplayBuffer
from cyclegan_pytorch import weights_init
# TensorBoard summary writer, created with default arguments at import time.
# The training loop below logs scalar losses through it, and the script
# flushes it once training has finished.
writer = SummaryWriter()
# Command-line interface for the training script. Every option has a default,
# so the script can be launched with no arguments at all.
parser = argparse.ArgumentParser(
    description="PyTorch implements `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks`"
)
parser.add_argument(
    "--dataroot",
    type=str,
    default="./data",
    help="path to datasets. (default:./data)",
)
parser.add_argument(
    "--dataset",
    type=str,
    default="horse2zebra",
    # The adjacent literals are concatenated by the parser, so each fragment
    # must carry its own separating space; the list is closed without a
    # dangling ", ".
    help="dataset name. (default:`horse2zebra`) "
    "Option: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, "
    "cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, "
    "iphone2dslr_flower, ae_photos]",
)
parser.add_argument(
    "--epochs",
    default=200,
    type=int,
    metavar="N",
    help="number of total epochs to run",
)
parser.add_argument(
    "--decay_epochs",
    type=int,
    default=100,
    help="epoch to start linearly decaying the learning rate to 0. (default:100)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=1,
    type=int,
    metavar="N",
    help="mini-batch size (default: 1), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr", type=float, default=0.0002, help="learning rate. (default:0.0002)"
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=100,
    type=int,
    metavar="N",
    help="print frequency. (default:100)",
)
parser.add_argument("--cuda", action="store_true", help="Enables cuda")
# Optional checkpoint paths; an empty string means "start from fresh weights".
parser.add_argument(
    "--netG_A2B", default="", help="path to netG_A2B (to continue training)"
)
parser.add_argument(
    "--netG_B2A", default="", help="path to netG_B2A (to continue training)"
)
parser.add_argument(
    "--netD_A", default="", help="path to netD_A (to continue training)"
)
parser.add_argument(
    "--netD_B", default="", help="path to netD_B (to continue training)"
)
parser.add_argument(
    "--image-size",
    type=int,
    default=256,
    help="size of the data crop (squared assumed). (default:256)",
)
parser.add_argument(
    "--outf",
    default="./outputs",
    help="folder to output images. (default:`./outputs`).",
)
parser.add_argument(
    "--manualSeed",
    type=int,
    help="Seed for initializing training. (default:none)",
)
# Parse the command line and echo the effective configuration.
args = parser.parse_args()
print(args)

# Create the output folders up front. ``exist_ok=True`` tolerates folders left
# over from an earlier run while still surfacing real failures (for example a
# permission error) instead of silently swallowing every OSError.
os.makedirs(args.outf, exist_ok=True)
os.makedirs("weights", exist_ok=True)

# Pick a random seed when none was given, and print it so the run can be
# repeated with --manualSeed.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
print("Random Seed: ", args.manualSeed)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not args.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )
# Dataset
# NOTE(review): the resize / random-crop / flip steps below are commented out,
# so this transform pipeline only converts to a tensor and normalizes, and
# ``args.image_size`` is not used anywhere in this script. Confirm this is
# intentional before training on images of mixed sizes.
dataset = ImageDataset(
    root=os.path.join(args.dataroot, args.dataset),
    transform=transforms.Compose(
        [
            # transforms.Resize(256, Image.BICUBIC),
            # transforms.Resize(int(args.image_size * 1.12), Image.BICUBIC),
            # transforms.RandomCrop(args.image_size),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
    unaligned=True,
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True
)

# Create every output folder independently. Previously "A" and "B" shared one
# try block, so an already-existing "A" raised before "B" was ever created.
for domain in ("A", "B"):
    os.makedirs(os.path.join(args.outf, args.dataset, domain), exist_ok=True)
os.makedirs(os.path.join("weights", args.dataset), exist_ok=True)
# All networks and losses live on the first CUDA device when --cuda is given,
# otherwise on the CPU.
device = torch.device("cuda:0" if args.cuda else "cpu")

# create model: one generator per translation direction and one discriminator
# per image domain.
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)

# Optionally resume from checkpoints; an empty path keeps the fresh weights.
# ``map_location=device`` lets a checkpoint saved on a GPU be restored on a
# CPU-only run (and vice versa) instead of failing during deserialization.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
for net, checkpoint_path in (
    (netG_A2B, args.netG_A2B),
    (netG_B2A, args.netG_B2A),
    (netD_A, args.netD_A),
    (netD_B, args.netD_B),
):
    if checkpoint_path:
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Loss criteria: L1 for the cycle and identity terms, MSE for the adversarial
# term. All are moved to the training device.
cycle_loss = torch.nn.L1Loss().to(device)
identity_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.MSELoss().to(device)

# Optimizers: a single Adam instance drives both generators; each
# discriminator gets its own. All three share the same hyper-parameters.
adam_options = {"lr": args.lr, "betas": (0.5, 0.999)}
generator_parameters = itertools.chain(
    netG_A2B.parameters(), netG_B2A.parameters()
)
optimizer_G = torch.optim.Adam(generator_parameters, **adam_options)
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), **adam_options)
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), **adam_options)

# Learning-rate schedule: the same multiplier function is attached to every
# optimizer. NOTE(review): the DecayLR arguments are presumably
# (total epochs, starting offset, epoch at which decay begins) -- confirm
# against cyclegan_pytorch.DecayLR.
lr_lambda = DecayLR(args.epochs, 0, args.decay_epochs).step
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B)
)
# Per-run loss histories. These must be real (empty) lists: the previous
# ``name = list[float]`` bound each name to the generic alias *type*, so any
# later ``.append`` would have failed.
g_losses: list[float] = []
d_losses: list[float] = []
identity_losses: list[float] = []
gan_losses: list[float] = []
cycle_losses: list[float] = []
# Buffers that the discriminator updates draw their fake images from
# (via ``push_and_pop`` in the training loop).
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Best-so-far bookkeeping, keyed on the discriminator-A loss. Infinity is used
# as the "nothing recorded yet" sentinel so that any finite loss counts as an
# improvement; a fixed number such as 1000.0 could be lower than a real loss.
errD_A_best = float("inf")
errD_A_last = float("inf")

# State-dict snapshots taken when the best loss improves; None until then.
best_netG_A2B = None
best_netG_B2A = None
best_netD_A = None
best_netD_B = None
def _snapshot_state(net):
    """Return an independent copy of ``net.state_dict()``.

    ``state_dict()`` hands back references to the live parameter tensors, so
    later optimizer steps would silently overwrite a stored "best" dict.
    Cloning every tensor freezes the values at the moment of the call.
    """
    return {
        name: tensor.detach().clone()
        for name, tensor in net.state_dict().items()
    }


# Main training loop: for every batch, update both generators jointly, then
# discriminator A, then discriminator B. Once per epoch, track the best
# snapshot, checkpoint every 10th epoch and advance the LR schedulers.
# NOTE(review): the source had lost its indentation; the per-epoch placement of
# the best-snapshot check, the checkpointing and the scheduler steps is
# reconstructed from intent -- confirm against the original layout.
for epoch in range(0, args.epochs):
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, data in progress_bar:
        # Monotonic iteration counter used as the TensorBoard step. Logging
        # every iteration at step ``epoch`` stacked all points of an epoch on
        # the same x position.
        global_step = epoch * len(dataloader) + i

        # get batch size data
        real_image_A = data["A"].to(device)
        real_image_B = data["B"].to(device)
        batch_size = real_image_A.size(0)

        # real data label is 1, fake data label is 0.
        real_label = torch.full(
            (batch_size, 1), 1, device=device, dtype=torch.float32
        )
        fake_label = torch.full(
            (batch_size, 1), 0, device=device, dtype=torch.float32
        )

        ##############################################
        # (1) Update G network: Generators A2B and B2A
        ##############################################

        # Set G_A and G_B's gradients to zero
        optimizer_G.zero_grad()

        # Identity loss (weighted by 5.0)
        # G_B2A(A) should equal A if real A is fed
        identity_image_A = netG_B2A(real_image_A)
        loss_identity_A = identity_loss(identity_image_A, real_image_A) * 5.0
        # G_A2B(B) should equal B if real B is fed
        identity_image_B = netG_A2B(real_image_B)
        loss_identity_B = identity_loss(identity_image_B, real_image_B) * 5.0

        # GAN loss
        # GAN loss D_A(G_B2A(B)): the generator wants D_A to output "real"
        fake_image_A = netG_B2A(real_image_B)
        fake_output_A = netD_A(fake_image_A)
        loss_GAN_B2A = adversarial_loss(fake_output_A, real_label)
        writer.add_scalar("errG_B2A/train", loss_GAN_B2A.item(), global_step)
        # GAN loss D_B(G_A2B(A)): the generator wants D_B to output "real"
        fake_image_B = netG_A2B(real_image_A)
        fake_output_B = netD_B(fake_image_B)
        loss_GAN_A2B = adversarial_loss(fake_output_B, real_label)
        # This is a generator loss, so it is logged under "errG_..." to match
        # "errG_B2A/train" (the tag previously read "errD_A2B/train").
        writer.add_scalar("errG_A2B/train", loss_GAN_A2B.item(), global_step)

        # Cycle loss (weighted by 10.0): A -> B -> A and B -> A -> B
        recovered_image_A = netG_B2A(fake_image_B)
        loss_cycle_ABA = cycle_loss(recovered_image_A, real_image_A) * 10.0

        recovered_image_B = netG_A2B(fake_image_A)
        loss_cycle_BAB = cycle_loss(recovered_image_B, real_image_B) * 10.0

        # Combined loss and calculate gradients
        errG = (
            loss_identity_A
            + loss_identity_B
            + loss_GAN_A2B
            + loss_GAN_B2A
            + loss_cycle_ABA
            + loss_cycle_BAB
        )

        # Calculate gradients for G_A and G_B
        errG.backward()
        writer.add_scalar("errG/train", errG.item(), global_step)
        # Update G_A and G_B's weights
        optimizer_G.step()

        ##############################################
        # (2) Update D network: Discriminator A
        ##############################################

        # Set D_A gradients to zero
        optimizer_D_A.zero_grad()

        # Real A image loss
        real_output_A = netD_A(real_image_A)
        errD_real_A = adversarial_loss(real_output_A, real_label)

        # Fake A image loss; the fake comes from the replay buffer and is
        # detached so no gradient flows back into the generator.
        fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
        fake_output_A = netD_A(fake_image_A.detach())
        errD_fake_A = adversarial_loss(fake_output_A, fake_label)

        # Combined loss and calculate gradients
        errD_A = (errD_real_A + errD_fake_A) / 2

        # Calculate gradients for D_A
        errD_A.backward()
        # Update D_A weights
        optimizer_D_A.step()

        # Keep a plain float, not the tensor, so no autograd graph is retained
        # across iterations.
        errD_A_last = errD_A.item()
        writer.add_scalar("errD_A/train", errD_A_last, global_step)

        ##############################################
        # (3) Update D network: Discriminator B
        ##############################################

        # Set D_B gradients to zero
        optimizer_D_B.zero_grad()

        # Real B image loss
        real_output_B = netD_B(real_image_B)
        errD_real_B = adversarial_loss(real_output_B, real_label)

        # Fake B image loss
        fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
        fake_output_B = netD_B(fake_image_B.detach())
        errD_fake_B = adversarial_loss(fake_output_B, fake_label)

        # Combined loss and calculate gradients
        errD_B = (errD_real_B + errD_fake_B) / 2

        # Calculate gradients for D_B
        errD_B.backward()
        # Update D_B weights
        optimizer_D_B.step()
        writer.add_scalar("errD_B/train", errD_B.item(), global_step)

        progress_bar.set_description(
            f"[{epoch}/{args.epochs - 1}][{i}/{len(dataloader) - 1}] "
            f"Loss_D: {(errD_A + errD_B).item():.4f} "
            f"Loss_G: {errG.item():.4f} "
            f"Loss_G_identity: {(loss_identity_A + loss_identity_B).item():.4f} "
            f"loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item():.4f} "
            f"loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item():.4f}"
        )

        # Periodic sample dumps are currently disabled (``vutils`` is not
        # imported by this script):
        # if i % args.print_freq == 0:
        #     vutils.save_image(real_image_A,
        #                       f"{args.outf}/{args.dataset}/A/real_samples_epoch_{epoch}_{i}.png",
        #                       normalize=True)
        #     vutils.save_image(real_image_B,
        #                       f"{args.outf}/{args.dataset}/B/real_samples_epoch_{epoch}_{i}.png",
        #                       normalize=True)
        #     fake_image_A = 0.5 * (netG_B2A(real_image_B).data + 1.0)
        #     fake_image_B = 0.5 * (netG_A2B(real_image_A).data + 1.0)
        #     vutils.save_image(fake_image_A.detach(),
        #                       f"{args.outf}/{args.dataset}/A/fake_samples_epoch_{epoch}_{i}.png",
        #                       normalize=True)
        #     vutils.save_image(fake_image_B.detach(),
        #                       f"{args.outf}/{args.dataset}/B/fake_samples_epoch_{epoch}_{i}.png",
        #                       normalize=True)

    # Record a new best when the epoch's last discriminator-A loss improves.
    # The threshold itself is updated (it used to be written to a different,
    # unused name), and each network is snapshotted from *its own* weights.
    if errD_A_best > errD_A_last:
        errD_A_best = errD_A_last
        best_netG_A2B = _snapshot_state(netG_A2B)
        best_netG_B2A = _snapshot_state(netG_B2A)
        best_netD_A = _snapshot_state(netD_A)
        best_netD_B = _snapshot_state(netD_B)

    if epoch % 10 == 0:
        # do check pointing
        torch.save(
            netG_A2B.state_dict(),
            f"weights/{args.dataset}/last_netG_A2B_epoch_{epoch}.pth",
        )
        torch.save(
            netG_B2A.state_dict(),
            f"weights/{args.dataset}/last_netG_B2A_epoch_{epoch}.pth",
        )
        torch.save(
            netD_A.state_dict(),
            f"weights/{args.dataset}/last_netD_A_epoch_{epoch}.pth",
        )
        torch.save(
            netD_B.state_dict(),
            f"weights/{args.dataset}/last_netD_B_epoch_{epoch}.pth",
        )

        # do best check pointing (skipped until a best snapshot exists, so a
        # pickled ``None`` is never written)
        if best_netG_A2B is not None:
            torch.save(
                best_netG_A2B,
                f"weights/{args.dataset}/best_netG_A2B_epoch_{epoch}.pth",
            )
            torch.save(
                best_netG_B2A,
                f"weights/{args.dataset}/best_netG_B2A_epoch_{epoch}.pth",
            )
            torch.save(
                best_netD_A,
                f"weights/{args.dataset}/best_netD_A_epoch_{epoch}.pth",
            )
            torch.save(
                best_netD_B,
                f"weights/{args.dataset}/best_netD_B_epoch_{epoch}.pth",
            )

    # Update learning rates (once per epoch)
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()
# Save the final weights of every network after the last epoch.
final_checkpoints = (
    (netG_A2B, f"weights/{args.dataset}/netG_A2B.pth"),
    (netG_B2A, f"weights/{args.dataset}/netG_B2A.pth"),
    (netD_A, f"weights/{args.dataset}/netD_A.pth"),
    (netD_B, f"weights/{args.dataset}/netD_B.pth"),
)
for net, checkpoint_path in final_checkpoints:
    torch.save(net.state_dict(), checkpoint_path)

# Push any pending scalars to disk, then release the event file; the writer
# was previously flushed but never closed.
writer.flush()
writer.close()