Commit d4ce84b

Cluster with CLIP features.
1 parent 4922373 commit d4ce84b

File tree

  requirements.txt
  requirements_windows.txt
  survos2/entity/cluster/patch_cluster.py

3 files changed: +41 -7 lines changed

requirements.txt (+1)

@@ -76,6 +76,7 @@ torchio
 torchmetrics
 tqdm
 traitlets
+transformers
 trio
 typing_extensions
 umap-learn

requirements_windows.txt (+1)

@@ -76,6 +76,7 @@ torchio
 torchmetrics
 tqdm
 traitlets
+transformers
 trio
 typing_extensions
 umap-learn

survos2/entity/cluster/patch_cluster.py (+39 -7)

@@ -6,8 +6,11 @@
 from skimage import img_as_ubyte
 
 from torchvision.ops import roi_align
-from collections import OrderedDict
 
+from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
+from tqdm.auto import tqdm
+
+from collections import OrderedDict
 from survos2.entity.cluster.cnn_features import CNNFeatures, prepare_3channel
 from survos2.entity.cluster.utils import get_surface
 from skimage.feature import hog
@@ -94,14 +97,14 @@ def extract_hog_features(selected_images, gpu_id=0):
 
     for i, img in enumerate(selected_images):
         fv = []
-        # img_3channel = np.stack((img, img, img)).T
+
         fd = hog(
-            img, # img_3channel,
+            img,
             orientations=8,
             pixels_per_cell=(8, 8),
             cells_per_block=(1, 1),
             visualize=False,
-        ) # , channel_axis=-1)
+        )
         fv.extend(fd)
         vec_mat[i, 0 : len(fv)] = fv
 
@@ -132,24 +135,53 @@ def extract_cnn_features2(patch_dict, gpu_id=0):
     return vec_mat, selected_images
 
 
+
+def extract_CLIP_features(selected_images, gpu_id=0, batch_size=16):
+    feature_mat = None
+    selected_3channel = [np.stack((img, img, img)).T for img in selected_images]
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model_id = "openai/clip-vit-base-patch32"
+    processor = CLIPProcessor.from_pretrained(model_id)
+    model = CLIPModel.from_pretrained(model_id).to(device)
+    image = processor(
+        text=None,
+        images=selected_3channel[0],
+        return_tensors='pt',
+        do_rescale=False)['pixel_values'].to(device)
+
+    for i in tqdm(range(0, len(selected_3channel), batch_size)):
+        batch = selected_3channel[i:i+batch_size]
+        batch = processor(
+            text=None,
+            images=batch,
+            return_tensors='pt',
+            padding=True
+        )['pixel_values'].to(device)
+        batch_emb = model.get_image_features(pixel_values=batch)
+        batch_emb = batch_emb.squeeze(0)
+        batch_emb = batch_emb.cpu().detach().numpy()
+        if feature_mat is None:
+            feature_mat = batch_emb
+        else:
+            feature_mat = np.concatenate((feature_mat, batch_emb), axis=0)
+
+    return feature_mat
+
 def patch_2d_features(img_volume, bvol_table, gpu_id=0, padvol=10):
     patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
     print(f"Number of roi extracted {len(patch_dict.keys())}")
     vec_mat = extract_cnn_features(selected_images, gpu_id)
-
     return vec_mat, selected_images
 
 
 def patch_2d_features2(img_volume, bvol_table, gpu_id=0, padvol=10):
     # patch_dict, bbs_info = sample_patch_roi(img_vol, bvol_table)
-
     raligned, raligned_labels, bvol_info = roi_pool_vol(img_volume, [bvol_table], padvol=padvol)
     patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
     print(f"Number of roi extracted {len(patch_dict.keys())}")
     vec_mat, selected_images = extract_cnn_features(patch_dict, gpu_id)
     return vec_mat, selected_images, patch_dict, bvol_info
 
-
 def prepare_patch_dict(raligned, raligned_labels, bvol_info):
     patch_dict = OrderedDict()
     for r, rlabel, bv_info in zip(raligned, raligned_labels, bvol_info):
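
For context, a minimal usage sketch of the new extract_CLIP_features (not part of the commit): it assumes float 2D patches as input and that scikit-learn is available alongside the umap-learn dependency already listed in requirements.txt; UMAP and KMeans are illustrative stand-ins for whatever downstream clustering step the SurVoS2 pipeline actually applies to these features.

import numpy as np
import umap  # provided by umap-learn, already in requirements.txt
from sklearn.cluster import KMeans  # assumed available; not added by this commit

from survos2.entity.cluster.patch_cluster import extract_CLIP_features

# Hypothetical input: a list of 2D float patches in [0, 1] (e.g. 64x64 ROI crops).
patches = [np.random.rand(64, 64).astype(np.float32) for _ in range(128)]

# Each patch is stacked to 3 channels inside extract_CLIP_features and embedded
# with openai/clip-vit-base-patch32, giving an (n_patches, 512) feature matrix.
features = extract_CLIP_features(patches, gpu_id=0, batch_size=16)

# Reduce the CLIP embeddings and cluster them; the cluster count is illustrative.
embedding = umap.UMAP(n_components=2, random_state=42).fit_transform(features)
labels = KMeans(n_clusters=5, n_init=10, random_state=42).fit_predict(embedding)
print(labels)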
