visualize.py
from datetime import datetime
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
from joeynmt.constants import BOS_TOKEN, EOS_TOKEN
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from data import Flickr8k
from model import Image2Caption


class NormalizeInverse(transforms.Normalize):
    """
    Undoes the normalization and returns the reconstructed images in the input domain.
    Copied from https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
    """

    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        return super().__call__(tensor.clone())


# inverse of the ImageNet statistics used by torchvision's pretrained encoders
normalize_inverse = NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
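
# A minimal round-trip sketch, assuming the input was normalized with the same
# ImageNet statistics:
#
#   >>> normalized = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#   ...                                   std=[0.229, 0.224, 0.225])(torch.rand(3, 224, 224))
#   >>> restored = normalize_inverse(normalized)  # values back in roughly [0, 1]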


class Tensorboard:
    """
    Tensorboard helper class
    """

    def __init__(self,
                 # note: the default log_dir timestamp is evaluated once, when the class is defined
                 log_dir: str = f'runs/image_captioning_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}',
                 image_idxs: List[int] = None,
                 device: str = 'cpu') -> None:
        if image_idxs is None:
            image_idxs = [7, 42, 128, 512, 1337]
        self.log_dir = log_dir
        self.writer = SummaryWriter(log_dir)
        self.image_idxs = image_idxs
        self.device = device

    def add_images_with_ground_truth(self, dataset: Flickr8k) -> None:
        """Log each monitored image together with all its reference captions (at step -1)."""
        for image_idx in self.image_idxs:
            img, caption, image_name = dataset[image_idx]
            self.writer.add_image(f'image-{image_idx}', normalize_inverse(img).cpu().detach().numpy())
            references = dataset.corpus.vocab.arrays_to_sentences(
                dataset.get_all_references_for_image_name(image_name))
            self.writer.add_text(f'image-{image_idx}',
                                 ' ' + ' | '.join(' '.join(sentence) for sentence in references),
                                 -1)
        self.writer.flush()
    def add_predicted_text(self, global_step: int, dataset: Flickr8k, model: Image2Caption,
                           max_output_length: int, beam_size: int = 1, beam_alpha: float = 0.4,
                           **kwargs) -> None:
        """Log the current caption prediction for each monitored image."""
        for image_idx in self.image_idxs:
            img, _, _ = dataset[image_idx]
            img = img.unsqueeze(0).to(self.device)  # add batch dimension
            prediction, attention_scores = model.predict(dataset, img, max_output_length,
                                                         beam_size, beam_alpha, **kwargs)
            decoded_prediction = dataset.corpus.vocab.arrays_to_sentences(prediction)[0]
            self.writer.add_text(f'image-{image_idx}', ' ' + ' '.join(decoded_prediction), global_step)
            if attention_scores is not None:  # only the RecurrentDecoder returns attention; the TransformerDecoder does not
                visualize_attention(img.squeeze(0), decoded_prediction, attention_scores[0],
                                    dataset.max_length,
                                    f'{self.log_dir}/{image_idx}-step_{global_step:03d}.png')
        self.writer.flush()
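

# A sketch of how this helper might be wired into a training loop. The
# Flickr8k/Image2Caption constructions below are placeholders, not the
# project's actual signatures:
#
#   dataset = Flickr8k(...)                         # hypothetical construction
#   model = Image2Caption(...).to('cuda')           # hypothetical construction
#   tensorboard = Tensorboard(device='cuda')
#   tensorboard.add_images_with_ground_truth(dataset)
#   for step in range(num_steps):
#       ...  # one training step
#       tensorboard.add_predicted_text(step, dataset, model, max_output_length=20)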


def visualize_attention(image: torch.Tensor, word_seq: List[str], attention_scores: np.ndarray,
                        max_length: int, file_name: str) -> None:
    """
    Generate a single figure of all predicted words, each overlaid with its attention map.

    :param image: original (normalized) image used for the prediction
    :param word_seq: predicted word sequence
    :param attention_scores: attention scores, one row per predicted word
    :param max_length: maximum number of predicted words (keeps all figures the same size)
    :param file_name: path where the figure is saved
    :return:
    """
    image = normalize_inverse(image).cpu().detach().numpy()
    image = (image.transpose((1, 2, 0)) * 255).astype(np.uint8)  # CHW float -> HWC uint8
    height = int(np.ceil((max_length + 1) / 5.))  # subplot rows: 5 words per row, plus the BOS panel
    fig, ax = plt.subplots(height, 5, figsize=(height * 3, 15))
    [axi.set_axis_off() for axi in ax.ravel()]  # hide axes for all subplots
    # first panel: the plain image, titled with the BOS token
    ax[0][0].set_title(BOS_TOKEN)
    ax[0][0].imshow(image)
    if len(word_seq) > 0:
        if word_seq[-1] != EOS_TOKEN:
            word_seq.append(EOS_TOKEN)
        extent = [0, image.shape[1], image.shape[0], 0]  # [left, right, bottom, top] in image coordinates
        # the attention vector is assumed to cover a square feature map, e.g. 49 scores -> 7x7
        attention_score_shape = [np.round(np.sqrt(attention_scores.shape[-1])).astype(np.uint8)] * 2
        for idx, (word, attention_score) in enumerate(zip(word_seq, attention_scores), start=1):
            current_axis = ax[idx // 5][idx % 5]
            current_axis.set_title(word)
            current_axis.imshow(image)
            current_axis.imshow(attention_score.reshape(attention_score_shape), cmap='hot',
                                interpolation='bilinear', alpha=0.5, extent=extent, origin='upper')
    fig.savefig(file_name)
    plt.close(fig)
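

if __name__ == '__main__':
    # Minimal smoke test for visualize_attention on synthetic data (a sketch;
    # not part of the training pipeline). The 7x7 attention grid mimics the
    # spatial feature map of a typical CNN encoder; the caption is made up.
    dummy_image = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])(torch.rand(3, 224, 224))
    dummy_words = ['a', 'dog', 'runs']
    # one attention row per word, plus one for the EOS token appended by visualize_attention
    dummy_attention = np.full((len(dummy_words) + 1, 7 * 7), 1.0 / (7 * 7))
    visualize_attention(dummy_image, dummy_words, dummy_attention,
                        max_length=8, file_name='attention_demo.png')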