@@ -6,8 +6,11 @@
 from skimage import img_as_ubyte
 
 from torchvision.ops import roi_align
-from collections import OrderedDict
 
+from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
+from tqdm.auto import tqdm
+
+from collections import OrderedDict
 from survos2.entity.cluster.cnn_features import CNNFeatures, prepare_3channel
 from survos2.entity.cluster.utils import get_surface
 from skimage.feature import hog
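Note on the new imports: they bring in the Hugging Face CLIP classes used by extract_CLIP_features below. CLIPModel.from_pretrained re-initializes the model on every call, so callers that run the extractor repeatedly may prefer to load the model once and reuse it. A minimal sketch of that pattern, using a hypothetical _load_clip helper that is not part of this diff:

from functools import lru_cache

import torch
from transformers import CLIPModel, CLIPProcessor

@lru_cache(maxsize=1)
def _load_clip(model_id="openai/clip-vit-base-patch32"):
    # Build the processor/model pair once; later calls return the cached pair.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = CLIPProcessor.from_pretrained(model_id)
    model = CLIPModel.from_pretrained(model_id).to(device).eval()
    return processor, model, device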
@@ -94,14 +97,14 @@ def extract_hog_features(selected_images, gpu_id=0):
 
     for i, img in enumerate(selected_images):
         fv = []
-        # img_3channel = np.stack((img, img, img)).T
+
         fd = hog(
-            img,  # img_3channel,
+            img,
             orientations=8,
             pixels_per_cell=(8, 8),
             cells_per_block=(1, 1),
             visualize=False,
-        )  # , channel_axis=-1)
+        )
         fv.extend(fd)
         vec_mat[i, 0 : len(fv)] = fv
 
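For reference, the hog call above returns a flattened descriptor whose length depends on the patch size: with orientations=8, pixels_per_cell=(8, 8), and cells_per_block=(1, 1), a hypothetical 64x64 patch yields (64/8) * (64/8) * 8 = 512 values. A minimal self-contained check (the patch size is an assumption, not part of the diff):

import numpy as np
from skimage.feature import hog

patch = np.random.rand(64, 64)  # stand-in for one selected image
fd = hog(
    patch,
    orientations=8,
    pixels_per_cell=(8, 8),
    cells_per_block=(1, 1),
    visualize=False,
)
# 8x8 cells, 1x1 blocks, 8 orientations per block -> 512 features
assert fd.shape == (512,)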
@@ -132,24 +135,51 @@ def extract_cnn_features2(patch_dict, gpu_id=0):
     return vec_mat, selected_images
 
 
+
+def extract_CLIP_features(selected_images, gpu_id=0, batch_size=16):
+    # CLIP expects RGB input, so stack each 2D patch into a 3-channel image.
+    selected_3channel = [np.stack((img, img, img)).T for img in selected_images]
+    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
+    model_id = "openai/clip-vit-base-patch32"
+    processor = CLIPProcessor.from_pretrained(model_id)
+    model = CLIPModel.from_pretrained(model_id).to(device)
+
+    feature_mat = None
+    for i in tqdm(range(0, len(selected_3channel), batch_size)):
+        batch = selected_3channel[i : i + batch_size]
+        # do_rescale=False: the patches are assumed to be floats already in [0, 1].
+        batch = processor(
+            text=None,
+            images=batch,
+            return_tensors="pt",
+            do_rescale=False,
+            padding=True,
+        )["pixel_values"].to(device)
+        with torch.no_grad():
+            batch_emb = model.get_image_features(pixel_values=batch)
+        batch_emb = batch_emb.cpu().numpy()
+        if feature_mat is None:
+            feature_mat = batch_emb
+        else:
+            feature_mat = np.concatenate((feature_mat, batch_emb), axis=0)
+
+    return feature_mat
+
 def patch_2d_features(img_volume, bvol_table, gpu_id=0, padvol=10):
     patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
     print(f"Number of roi extracted {len(patch_dict.keys())}")
     vec_mat = extract_cnn_features(selected_images, gpu_id)
-
     return vec_mat, selected_images
 
 
 def patch_2d_features2(img_volume, bvol_table, gpu_id=0, padvol=10):
     # patch_dict, bbs_info = sample_patch_roi(img_vol, bvol_table)
-
     raligned, raligned_labels, bvol_info = roi_pool_vol(img_volume, [bvol_table], padvol=padvol)
     patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
     print(f"Number of roi extracted {len(patch_dict.keys())}")
    vec_mat, selected_images = extract_cnn_features(patch_dict, gpu_id)
     return vec_mat, selected_images, patch_dict, bvol_info
 
-
 def prepare_patch_dict(raligned, raligned_labels, bvol_info):
     patch_dict = OrderedDict()
     for r, rlabel, bv_info in zip(raligned, raligned_labels, bvol_info):
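A minimal usage sketch for the new extractor, assuming 2D float patches scaled to [0, 1]; the input data and the scikit-learn clustering step are illustrative assumptions, not part of this diff:

import numpy as np
from sklearn.cluster import KMeans

# Hypothetical input: a list of 2D float patches in [0, 1].
selected_images = [np.random.rand(64, 64).astype(np.float32) for _ in range(40)]

feature_mat = extract_CLIP_features(selected_images, batch_size=16)
print(feature_mat.shape)  # (40, 512) for clip-vit-base-patch32

# Cluster the embeddings, e.g. to group visually similar patches.
labels = KMeans(n_clusters=4, n_init=10).fit_predict(feature_mat)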