-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
62 lines (40 loc) · 1.71 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger
from ignite.engine import Engine,Events
from torchvision.utils import make_grid
from ignite.contrib.handlers import tqdm_logger
from data import valid_dl
import torch
def attach_ignite(
    trainer: Engine,
    gen,
    log_dir: str = './pix2pix_log',
):
    """Attach a progress bar, TensorBoard loss logging and image logging to ``trainer``.

    Args:
        trainer: The Ignite training engine. Its step output must be a dict with
            ``'generator_loss'`` and ``'discriminator_loss'`` keys (read by the
            loss ``output_transform`` below).
        gen: The generator module. At the end of every epoch it is run on the
            first batch of ``valid_dl`` to produce preview images.
        log_dir: Directory for TensorBoard event files. Defaults to the
            previously hard-coded ``'./pix2pix_log'``.

    Returns:
        The ``TensorboardLogger`` used for logging. It is also closed
        automatically when ``trainer`` fires ``Events.COMPLETED``.
    """
    tb_logger = TensorboardLogger(log_dir=log_dir)

    # Show the raw step output (the loss dict) on the tqdm progress bar.
    tqdm_logger.ProgressBar().attach(trainer, output_transform=lambda x: x)

    # Losses are written once per epoch, so the logged value is the output of
    # the epoch's last iteration, not an epoch average.
    tb_logger.attach_output_handler(
        engine=trainer,
        event_name=Events.EPOCH_COMPLETED,
        tag='train',
        output_transform=lambda x: {
            'g_loss': x['generator_loss'],
            'd_loss': x['discriminator_loss'],
        },
    )

    def log_generated_images(engine, logger, gen, epoch):
        """Log input / generated / target image grids for one validation batch.

        ``engine`` is unused; it is kept for signature compatibility.
        """
        # Remember the current mode and restore it afterwards. Otherwise the
        # generator would keep training with dropout/batch-norm in eval mode
        # after the first epoch.
        was_training = gen.training
        gen.eval()
        try:
            with torch.no_grad():
                input_images, real_output_images = next(iter(valid_dl))
                # Feed the generator on its own device. This is a no-op when
                # the batch is already there.
                param = next(gen.parameters(), None)
                if param is not None:
                    input_images = input_images.to(param.device)
                fake_imgs = gen(input_images)
                # value_range=(-1, 1) presumes images normalised to [-1, 1]
                # (e.g. a tanh generator output) -- TODO confirm against data.
                for tag, images in (
                    ('input_images', input_images),
                    ('fake_images', fake_imgs),
                    ('real_images', real_output_images),
                ):
                    grid = make_grid(images, normalize=True, value_range=(-1, 1))
                    logger.writer.add_image(tag, grid, epoch)
        finally:
            gen.train(was_training)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_images(engine):
        epoch = engine.state.epoch
        log_generated_images(engine, tb_logger, gen, epoch)

    @trainer.on(Events.COMPLETED)
    def close_logger(engine):
        # Flush and release the TensorBoard event writer once training ends.
        tb_logger.close()

    return tb_logger