Merge pull request #21 from DiamondLightSource/ClusterChange

Cluster with CLIP features.

penningavery authored Oct 11, 2023
2 parents 4922373 + d4ce84b commit 3599b71

Showing 3 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -76,6 +76,7 @@ torchio
torchmetrics
tqdm
traitlets
transformers
trio
typing_extensions
umap-learn
1 change: 1 addition & 0 deletions requirements_windows.txt
@@ -76,6 +76,7 @@ torchio
torchmetrics
tqdm
traitlets
transformers
trio
typing_extensions
umap-learn
46 changes: 39 additions & 7 deletions survos2/entity/cluster/patch_cluster.py
@@ -6,8 +6,11 @@
from skimage import img_as_ubyte

from torchvision.ops import roi_align
from collections import OrderedDict

from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
from tqdm.auto import tqdm

from survos2.entity.cluster.cnn_features import CNNFeatures, prepare_3channel
from survos2.entity.cluster.utils import get_surface
from skimage.feature import hog
@@ -94,14 +97,14 @@ def extract_hog_features(selected_images, gpu_id=0):

    for i, img in enumerate(selected_images):
        fv = []
        fd = hog(
            img,
            orientations=8,
            pixels_per_cell=(8, 8),
            cells_per_block=(1, 1),
            visualize=False,
        )
        fv.extend(fd)
        vec_mat[i, 0 : len(fv)] = fv

@@ -132,24 +135,53 @@ def extract_cnn_features2(patch_dict, gpu_id=0):
    return vec_mat, selected_images


def extract_CLIP_features(selected_images, gpu_id=0, batch_size=16):
    # Embed 2D patches with a pretrained CLIP vision encoder (ViT-B/32).
    feature_mat = None
    # CLIP expects RGB input, so replicate each grayscale patch to 3 channels.
    selected_3channel = [np.stack((img, img, img)).T for img in selected_images]
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    model_id = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(model_id)
    model = CLIPModel.from_pretrained(model_id).to(device)

    for i in tqdm(range(0, len(selected_3channel), batch_size)):
        batch = selected_3channel[i : i + batch_size]
        batch = processor(
            text=None,
            images=batch,
            return_tensors="pt",
            padding=True,
        )["pixel_values"].to(device)
        # One embedding per patch; inference only, so no gradients are needed.
        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=batch)
        batch_emb = batch_emb.cpu().numpy()
        if feature_mat is None:
            feature_mat = batch_emb
        else:
            feature_mat = np.concatenate((feature_mat, batch_emb), axis=0)

    return feature_mat
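
# Editor's sketch, not part of this commit: a minimal call pattern for the
# function above, assuming the patches are 2D numpy arrays.
# patches = [np.random.rand(64, 64).astype(np.float32) for _ in range(8)]
# feats = extract_CLIP_features(patches, batch_size=4)
# feats.shape == (8, 512)  # one ViT-B/32 embedding per patch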

def patch_2d_features(img_volume, bvol_table, gpu_id=0, padvol=10):
    raligned, raligned_labels, bvol_info = roi_pool_vol(img_volume, [bvol_table], padvol=padvol)
    patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
    print(f"Number of roi extracted {len(patch_dict.keys())}")
    vec_mat, selected_images = extract_cnn_features(patch_dict, gpu_id)

    return vec_mat, selected_images


def patch_2d_features2(img_volume, bvol_table, gpu_id=0, padvol=10):
    # patch_dict, bbs_info = sample_patch_roi(img_vol, bvol_table)

    raligned, raligned_labels, bvol_info = roi_pool_vol(img_volume, [bvol_table], padvol=padvol)
    patch_dict, bvol_info = prepare_patch_dict(raligned, raligned_labels, bvol_info)
    print(f"Number of roi extracted {len(patch_dict.keys())}")
    vec_mat, selected_images = extract_cnn_features(patch_dict, gpu_id)
    return vec_mat, selected_images, patch_dict, bvol_info


def prepare_patch_dict(raligned, raligned_labels, bvol_info):
    patch_dict = OrderedDict()
    for r, rlabel, bv_info in zip(raligned, raligned_labels, bvol_info):
        ...
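
Taken together, the pieces above suggest how the commit title is meant to be used: pool patches from a volume, embed them with CLIP, and cluster the embeddings. The sketch below is an illustration, not code from this commit; it assumes an img_volume array and bvol_table as in patch_2d_features2's signature, and the KMeans clusterer and five-cluster count are arbitrary stand-ins.

import numpy as np
from sklearn.cluster import KMeans

# Pool padded ROIs and per-ROI 2D patches (helpers from this module).
vec_mat, selected_images, patch_dict, bvol_info = patch_2d_features2(
    img_volume, bvol_table, gpu_id=0, padvol=10
)

# Swap the CNN features for CLIP image embeddings: one 512-d row per patch.
feature_mat = extract_CLIP_features(selected_images, gpu_id=0, batch_size=16)

# Group patches by embedding similarity; the cluster count is illustrative.
labels = KMeans(n_clusters=5, random_state=0).fit_predict(feature_mat)
print(feature_mat.shape, np.bincount(labels))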
