interpolating_points.py
"""
Interpolate between points using a trained HyCoCLIP, MERU or CLIP model,
and a pool of text and images (and their encoded representations).
"""
from __future__ import annotations
import os
import cv2
import numpy as np
import argparse
import pandas as pd
import json
from io import BytesIO
from PIL import Image
import base64
from enum import Enum
import albumentations as A
from tqdm import tqdm
import torch
from torchvision import transforms as T
from huggingface_hub import snapshot_download
from hycoclip import lorentz as L
from hycoclip.config import LazyConfig, LazyFactory
from hycoclip.models import HyCoCLIP, MERU, CLIPBaseline
from hycoclip.utils.checkpointing import CheckpointManager
from hycoclip.tokenizer import Tokenizer


parser = argparse.ArgumentParser(description=__doc__)
_AA = parser.add_argument
_AA("--checkpoint-path", help="Path to checkpoint of a trained HyCoCLIP/MERU/CLIP model.")
_AA("--train-config", help="Path to train config (.yaml/py) for given checkpoint.")
_AA("--image-path", help="Path to an image (.jpg) for perfoming traversal.")
_AA("--target-image-path", help="Path to an image (.jpg) for perfoming traversal to this target image.")
_AA("--steps", type=int, default=50, help="Number of traversal steps.")
_AA("--download-data", action='store_true', help="Download the data for the Flickr dataset.")
_AA("--data-path", help="Path to download the data for the Flickr dataset.")
_AA("--feats-path", help="Path to the features for the Flickr dataset.")
_AA("--image-to-image-traversal", action='store_true', help="Do interpolated traversal from image to image.")
def interpolate(model, feats: torch.Tensor, root_feat: torch.Tensor, steps: int):
"""
    Interpolate between `feats` and `root_feat` (the `[ROOT]` embedding, or any other
    endpoint such as a target image feature), depending on model type.
"""
    # Linear interpolation between root and image features. For HyCoCLIP and MERU,
    # this happens in the tangent space of the origin, so both endpoints are mapped
    # there first (for the all-zeros `[ROOT]` embedding this is a no-op).
    if isinstance(model, (HyCoCLIP, MERU)):
        feats = L.log_map0(feats, model.curv.exp())
        root_feat = L.log_map0(root_feat, model.curv.exp())
interp_feats = [
torch.lerp(root_feat, feats, weight.item())
for weight in torch.linspace(0.0, 1.0, steps=steps)
]
interp_feats = torch.stack(interp_feats)
    # Lift back onto the hyperboloid (for HyCoCLIP and MERU), or L2-normalize (for CLIP).
    # The endpoints were already mapped to the tangent space above, so only the
    # interpolated points need the exponential map here.
    if isinstance(model, (HyCoCLIP, MERU)):
        interp_feats = L.exp_map0(interp_feats, model.curv.exp())
else:
interp_feats = torch.nn.functional.normalize(interp_feats, dim=-1)
# Reverse the traversal order: (image first, root last)
return interp_feats.flip(0)
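
# Note on the hyperbolic case: `L.log_map0` / `L.exp_map0` map points between the
# hyperboloid and the tangent space at the origin, so the straight-line lerp in
# `interpolate` happens where linear interpolation is well defined. A minimal,
# illustrative sketch of the same round trip:
#
#   curv = model.curv.exp()                    # learned curvature (positive)
#   tangent = L.log_map0(feats, curv)          # hyperboloid -> tangent space at origin
#   midpoint = torch.lerp(root_feat, tangent, 0.5)
#   on_manifold = L.exp_map0(midpoint, curv)   # back onto the hyperboloid

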
def calc_scores(
model, image_feats: torch.Tensor, all_feats: torch.Tensor, has_root: bool
):
"""
Calculate similarity scores between the input image and dataset features depending
on model type.
Args:
has_root: Flag to indicate whether the last text embedding (at dim=0)
is the `[ROOT]` embedding.
"""
all_scores = []
if isinstance(model, (HyCoCLIP, MERU)):
for feats_batch in all_feats.split(65536):
scores = L.pairwise_inner(image_feats, feats_batch, model.curv.exp())
all_scores.append(scores)
all_scores = torch.cat(all_scores, dim=1)
return all_scores
else:
        # For CLIP, both sets of features are already L2-normalized, so cosine
        # similarity reduces to a plain matrix product; the model itself is not needed.
return image_feats @ all_feats.T
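

# Usage sketch (illustrative): `calc_scores(model, interp_feats, item_feats, has_root=True)`
# returns a [num_steps, num_items] score matrix; `main()` below takes the arg-max over
# the item dimension at every interpolation step to find the nearest neighbor.

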
_INTER_STR_TO_CV2 = {
"nearest": cv2.INTER_NEAREST,
"linear": cv2.INTER_LINEAR,
"bilinear": cv2.INTER_LINEAR,
"cubic": cv2.INTER_CUBIC,
"bicubic": cv2.INTER_CUBIC,
"area": cv2.INTER_AREA,
"lanczos": cv2.INTER_LANCZOS4,
"lanczos4": cv2.INTER_LANCZOS4,
}


def inter_str_to_cv2(inter_str):
inter_str = inter_str.lower()
if inter_str not in _INTER_STR_TO_CV2:
raise ValueError(f"Invalid option for interpolation: {inter_str}")
return _INTER_STR_TO_CV2[inter_str]
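

# For example, inter_str_to_cv2("lanczos") returns cv2.INTER_LANCZOS4; unrecognized
# names raise a ValueError.

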
class ResizeMode(Enum):
no = 0 # pylint: disable=invalid-name
keep_ratio = 1 # pylint: disable=invalid-name
center_crop = 2 # pylint: disable=invalid-name
border = 3 # pylint: disable=invalid-name
keep_ratio_largest = 4 # pylint: disable=invalid-name
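

# The Resizer below pre-processes raw images before they reach the CLIP-style transform:
# "keep_ratio" rescales the shortest side to `image_size`, "center_crop" additionally
# crops to a square, "keep_ratio_largest" rescales the longest side, "border" rescales
# the longest side and pads to a square with a white border, and "no" leaves the size
# unchanged. `main()` instantiates it with image_size=224 and resize_mode="border".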
class Resizer:
def __init__(
self,
image_size,
resize_mode,
resize_only_if_bigger,
upscale_interpolation="lanczos",
downscale_interpolation="area",
encode_quality=95,
skip_reencode=False,
min_image_size=0,
max_image_area=float("inf"),
max_aspect_ratio=float("inf"),
):
self.image_size = image_size
if isinstance(resize_mode, str):
if resize_mode not in ResizeMode.__members__: # pylint: disable=unsupported-membership-test
raise ValueError(f"Invalid option for resize_mode: {resize_mode}")
resize_mode = ResizeMode[resize_mode]
self.resize_mode = resize_mode
self.min_image_size = min_image_size
self.max_image_area = max_image_area
self.max_aspect_ratio = max_aspect_ratio
self.resize_only_if_bigger = resize_only_if_bigger
cv2_img_quality = int(cv2.IMWRITE_JPEG_QUALITY)
self.encode_params = [cv2_img_quality, encode_quality]
self.what_ext = "jpeg"
self.skip_reencode = skip_reencode
self.upscale_interpolation = inter_str_to_cv2(upscale_interpolation)
self.downscale_interpolation = inter_str_to_cv2(downscale_interpolation)

    def __call__(self, img):
cv2.setNumThreads(1)
if img is None:
raise ValueError("Image decoding error")
if len(img.shape) == 3 and img.shape[-1] == 4:
# alpha matting with white background
alpha = img[:, :, 3, np.newaxis]
img = alpha / 255 * img[..., :3] + 255 - alpha
img = np.rint(img.clip(min=0, max=255)).astype(np.uint8)
original_height, original_width = img.shape[:2]
# check if image is too small
if min(original_height, original_width) < self.min_image_size:
return None, None, None, None, None, "image too small"
if original_height * original_width > self.max_image_area:
return None, None, None, None, None, "image area too large"
# check if wrong aspect ratio
if max(original_height, original_width) / min(original_height, original_width) > self.max_aspect_ratio:
return None, None, None, None, None, "aspect ratio too large"
# resizing in following conditions
if self.resize_mode in (ResizeMode.keep_ratio, ResizeMode.center_crop):
downscale = min(original_width, original_height) > self.image_size
if not self.resize_only_if_bigger or downscale:
interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation
img = A.smallest_max_size(img, self.image_size, interpolation=interpolation)
if self.resize_mode == ResizeMode.center_crop:
img = A.center_crop(img, self.image_size, self.image_size)
elif self.resize_mode in (ResizeMode.border, ResizeMode.keep_ratio_largest):
downscale = max(original_width, original_height) > self.image_size
if not self.resize_only_if_bigger or downscale:
interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation
img = A.longest_max_size(img, self.image_size, interpolation=interpolation)
if self.resize_mode == ResizeMode.border:
img = A.pad(
img,
self.image_size,
self.image_size,
border_mode=cv2.BORDER_CONSTANT,
value=[255, 255, 255],
)
height, width = img.shape[:2]
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
img = Image.fromarray(img)
return img, width, height, original_width, original_height, None
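

# Minimal usage sketch for the resizer ("example.jpg" is a placeholder path):
#
#   resizer = Resizer(image_size=224, resize_mode="border", resize_only_if_bigger=False)
#   pil_img, w, h, orig_w, orig_h, err = resizer(np.array(Image.open("example.jpg")))
#   # On success `pil_img` is a 224x224 PIL image and `err` is None; on failure the
#   # first five values are None and `err` holds the reason.

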
@torch.inference_mode()
def get_data_feats(device,
resizer: Resizer,
tsv_path: str,
model: HyCoCLIP | MERU | CLIPBaseline) -> tuple[list[str], torch.Tensor]:
tokenizer = Tokenizer()
image_transform = T.Compose(
[T.Resize(224, T.InterpolationMode.BICUBIC), T.CenterCrop(224), T.ToTensor()]
)
    crop_dim_cutoff = 32 * 32  # skip entity box crops with area below 32x32 pixels
total_shards = 17 # Number of shards in the dataset. 17 for Flickr grounded dataset.
min_batch_size = 512
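    # Each shard is a (data_id, data) TSV in which `data` is a JSON record holding a
    # base64-encoded image, its caption, and box annotations. Full images and captions
    # are encoded alongside their box crops and grounded phrases, and encoding is done
    # in batches once at least `min_batch_size` items have accumulated.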
previous_img_file_name = ""
item_list = []
representation_list = []
images_to_encode = []
texts_to_encode = []
image_items = []
text_items = []
for shard_num in range(total_shards):
sample_train_shard = pd.read_csv(f'{tsv_path}/train-{shard_num:04d}.tsv', sep='\t', header=None, names=['data_id', 'data'])
samples_in_shard = len(sample_train_shard['data'])
for idx in tqdm(range(samples_in_shard), desc=f"Processing shard {shard_num+1}/{total_shards}"):
sample_data = sample_train_shard['data'][idx]
sample_data = json.loads(sample_data)
img_file_name = sample_data['file_name']
if img_file_name != previous_img_file_name:
previous_img_file_name = img_file_name
image_bytes = base64.b64decode(bytes(sample_data['image'], encoding='raw_unicode_escape'))
image = Image.open(BytesIO(image_bytes))
image_tensor, _, _, _, _, error = resizer(np.array(image))
image_tensor = image_transform(image_tensor)
images_to_encode.append(image_tensor)
image_items.append(img_file_name)
bboxes_of_image = []
texts_of_image = []
img_caption = sample_data['caption']
texts_to_encode.append(img_caption)
text_items.append(img_caption)
for n, annotation in enumerate(sample_data['annos']):
bbox_dims = annotation['bbox']
if str(bbox_dims) not in bboxes_of_image:
bboxes_of_image.append(str(bbox_dims))
left = bbox_dims[0]
top = bbox_dims[1]
right = left + bbox_dims[2]
bottom = top + bbox_dims[3]
if (right-left)*(bottom-top) >= crop_dim_cutoff:
entity_image = image.crop((left, top, right, bottom))
entity_image, _, _, _, _, error = resizer(np.array(entity_image))
entity_image = image_transform(entity_image)
images_to_encode.append(entity_image)
image_items.append(f'{img_file_name}_{str(bbox_dims)}')
tokens_positive = annotation['tokens_positive'][0]
entity_text = img_caption[tokens_positive[0]:tokens_positive[1]]
if entity_text not in texts_of_image:
texts_of_image.append(entity_text)
texts_to_encode.append(entity_text)
text_items.append(entity_text)
if len(images_to_encode) >= min_batch_size:
representation_list.append(model.encode_image(torch.stack(images_to_encode).to(device), project=True))
item_list.extend(image_items)
images_to_encode = []
image_items = []
if len(texts_to_encode) >= min_batch_size:
text_tokens = tokenizer(texts_to_encode)
representation_list.append(model.encode_text(text_tokens, project=True))
item_list.extend(text_items)
texts_to_encode = []
text_items = []
    # Flush any remaining items that did not fill a complete batch.
    if images_to_encode:
        representation_list.append(model.encode_image(torch.stack(images_to_encode).to(device), project=True))
        item_list.extend(image_items)
    if texts_to_encode:
        text_tokens = tokenizer(texts_to_encode)
        representation_list.append(model.encode_text(text_tokens, project=True))
        item_list.extend(text_items)
item_feats = torch.cat(representation_list, dim=0)
return item_list, item_feats
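

# The cached features file written by `main()` via torch.save is simply the tuple
# `(item_list, item_feats)` returned above, so a later run can reload it directly
# with torch.load(_A.feats_path).

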
@torch.inference_mode()
def main(_A: argparse.Namespace):
resizer = Resizer(
image_size=224,
resize_mode="border",
resize_only_if_bigger=False,
upscale_interpolation="lanczos",
downscale_interpolation="area",
encode_quality=95,
skip_reencode=False,
min_image_size=0,
max_image_area=float("inf"),
max_aspect_ratio=float("inf"),
)
# Get the current device (this will be `cuda:0` here by default) or use CPU.
device = (
torch.cuda.current_device()
if torch.cuda.is_available()
else torch.device("cpu")
)
# Create the model using training config and load pre-trained weights.
_C_TRAIN = LazyConfig.load(_A.train_config)
model = LazyFactory.build_model(_C_TRAIN, device).eval()
CheckpointManager(model=model).load(_A.checkpoint_path)
if isinstance(model, (HyCoCLIP, MERU)):
root_feat = torch.zeros(_C_TRAIN.model.embed_dim, device=device)
else:
# CLIP model checkpoint should have the 'root' embedding.
root_feat = torch.load(_A.checkpoint_path)["root"].to(device)
if not os.path.exists(_A.feats_path):
if _A.download_data:
snapshot_download(repo_id="gligen/flickr_tsv", repo_type="dataset", local_dir=_A.data_path, local_dir_use_symlinks=False)
item_list, item_feats = get_data_feats(device, resizer, _A.data_path, model)
torch.save((item_list, item_feats), _A.feats_path)
else:
item_list, item_feats = torch.load(_A.feats_path)
# Add [ROOT] to the pool of text feats.
item_list.append("[ROOT]")
item_feats = torch.cat([item_feats, root_feat[None, ...]])
print(f"Total items in item_list: {len(item_list)}")
print(f"Size of item_feats: {item_feats.size()}")
# ------------------------------------------------------------------------
print(f"\nPerforming image to root traversals with source: {_A.image_path}...")
# ------------------------------------------------------------------------
image = Image.open(_A.image_path).convert("RGB")
image_transform = T.Compose(
[T.Resize(224, T.InterpolationMode.BICUBIC), T.CenterCrop(224), T.ToTensor()]
)
image, _, _, _, _, error = resizer(np.array(image))
image = image_transform(image).to(device)
image_feats = model.encode_image(image[None, ...], project=True)[0]
interp_feats = interpolate(model, image_feats, root_feat, _A.steps)
nn1_scores = calc_scores(model, interp_feats, item_feats, has_root=True)
nn1_scores, _nn1_idxs = nn1_scores.max(dim=-1)
nn1_texts = [item_list[_idx.item()] for _idx in _nn1_idxs]
    # De-duplicate retrieved items (multiple interpolation steps may share a nearest
    # neighbor) and print them.
    print("Items retrieved from [IMAGE] -> [ROOT] traversal:")
unique_nn1_texts = []
for _text in nn1_texts:
if _text not in unique_nn1_texts:
unique_nn1_texts.append(_text)
print(f" - {_text}")
if _A.image_to_image_traversal:
# ------------------------------------------------------------------------
print(f"\nPerforming image to image traversals with source: {_A.image_path} and target: {_A.target_image_path}...")
# ------------------------------------------------------------------------
image = Image.open(_A.image_path).convert("RGB")
target_image = Image.open(_A.target_image_path).convert("RGB")
image_transform = T.Compose(
[T.Resize(224, T.InterpolationMode.BICUBIC), T.CenterCrop(224), T.ToTensor()]
)
image, _, _, _, _, error = resizer(np.array(image))
image = image_transform(image).to(device)
image_feats = model.encode_image(image[None, ...], project=True)[0]
target_image, _, _, _, _, error = resizer(np.array(target_image))
target_image = image_transform(target_image).to(device)
target_image_feats = model.encode_image(target_image[None, ...], project=True)[0]
interp_feats = interpolate(model, image_feats, target_image_feats, _A.steps)
nn1_scores = calc_scores(model, interp_feats, item_feats, has_root=True)
nn1_scores, _nn1_idxs = nn1_scores.max(dim=-1)
nn1_texts = [item_list[_idx.item()] for _idx in _nn1_idxs]
        # De-duplicate retrieved items (multiple interpolation steps may share a nearest
        # neighbor) and print them.
        print("Items retrieved from [SOURCE IMAGE] -> [TARGET IMAGE] traversal:")
unique_nn1_texts = []
for _text in nn1_texts:
if _text not in unique_nn1_texts:
unique_nn1_texts.append(_text)
print(f" - {_text}")


if __name__ == "__main__":
_A = parser.parse_args()
main(_A)