How to evaluate the pretrained model on the jumpcp dataset. #14

Open
Algolzw opened this issue Jan 6, 2025 · 1 comment

Algolzw commented Jan 6, 2025

Hi, thanks for your great work! I am trying to evaluate the pretrained model (weights: cpjump_cellpaint_bf_channelvit_small_p8_with_hcs_supervised) on jumpcp using a custom dataloader, since I would rather not carry so many config files in my project. The data-loading code is:

def __getitem__(self, index):
    if self.well_loc[index] not in self.well2lbl[self.perturbation_type]:
        # this well is not labeled
        return None

    image = self.read_im(self.data_path[index]) #/ 255.
    image = self.transform(image)
    label = self.well2lbl[self.perturbation_type][self.well_loc[index]]

    return {'image': image, 'label': label, 'channels': np.arange(8)}

But the results are poor (~0.2% accuracy). Could you give me some suggestions on how to apply the trained model with a custom dataloader, or on how the dataloader should be designed?
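One thing that may be worth checking with a custom loader: `__getitem__` above returns `None` for unlabeled wells, and PyTorch's default collation cannot batch `None` entries. If the loader relies on default collation, a wrapper that drops those samples first is one option. The sketch below is only an illustration: the name `drop_none_collate` and its `base_collate` parameter are hypothetical and not part of ChannelViT, and the wrapper is kept in plain Python so it can be exercised without PyTorch.

```python
from typing import Any, Callable, List, Optional, Sequence


def drop_none_collate(
    batch: Sequence[Optional[Any]],
    base_collate: Callable[[List[Any]], Any] = list,
) -> Optional[Any]:
    """Drop samples that are None, then hand the rest to `base_collate`.

    Returns None when every sample in the batch was None, so the
    evaluation loop has to skip such batches explicitly.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return base_collate(kept)


# Assumed usage with PyTorch (not executed here):
#   from functools import partial
#   from torch.utils.data import DataLoader, default_collate
#   loader = DataLoader(
#       dataset,
#       batch_size=32,
#       collate_fn=partial(drop_none_collate, base_collate=default_collate),
#   )
```

Batches built this way can be smaller than `batch_size`, so accuracy should be accumulated per sample rather than averaged per batch.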

Algolzw commented Jan 6, 2025

And here is the full code for the dataset, in which I load the images locally for I/O efficiency:

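import numpy as np
import pandas as pd
from torch.utils.data import Dataset

# CellAugmentation, jumpcp_mean_, jumpcp_stds_ and load_meta_data are
# project-specific helpers that are not shown in this post.
class JUMPCPDataset(Dataset):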
    def __init__(self, root, cyto_mask_path, split='train', perturbation_type='compound'):
        super().__init__()
        self.root = root
        df = pd.read_parquet(cyto_mask_path)
        df = self.get_split(df, split)

        self.data_path = list(df["path"])
        self.data_id = list(df["ID"])
        self.well_loc = list(df["well_loc"])
        self.perturbation_type = perturbation_type
        # self.transform = train_transform() if split == 'train' else test_transform()
        self.transform = CellAugmentation(
                            is_train=(split == 'train'), 
                            global_resize=224,
                            normalization_mean=jumpcp_mean_,
                            normalization_std=jumpcp_stds_)

        self.plate2id, self.field2id, self.well2id, self.well2lbl = load_meta_data()

    def get_split(self, df, split_name, seed=0):
        ###### split copy from ChannelViT #####
        np.random.seed(seed)
        perm = np.random.permutation(df.index)
        m = len(df.index)
        train_end = int(0.6 * m)
        validate_end = int(0.2 * m) + train_end

        if split_name == "train":
            return df.iloc[perm[:train_end]]
        elif split_name == "valid":
            return df.iloc[perm[train_end:validate_end]]
        elif split_name == "test":
            return df.iloc[perm[validate_end:]]
        else:
            raise ValueError("Unknown split")

    def read_im(self, file_path):
        file_path = file_path.replace('s3://insitro-research-2023-context-vit/jumpcp', self.root)
        image = np.load(file_path, allow_pickle=True) ### jumpcp uses .npy image file
        channel_dim = np.argmin(image.shape)
        image = np.moveaxis(image, channel_dim, -1) if channel_dim != -1 else image
        # return torch.tensor(image, dtype=torch.float) # to tensor
        return image # numpy

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, index):
        if self.well_loc[index] not in self.well2lbl[self.perturbation_type]:
            # this well is not labeled
            return None

        image = self.read_im(self.data_path[index]) #/ 255.
        image = self.transform(image)
        label = self.well2lbl[self.perturbation_type][self.well_loc[index]]

        return {'image': image, 'label': label, 'channels': np.arange(8)}
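The commented-out `/ 255.` in `__getitem__` suggests the intensity scale of the loaded arrays is an open question. If `jumpcp_mean_` and `jumpcp_stds_` were computed on a different scale than the arrays returned by `read_im`, the model would see inputs far from what it was trained on, which could plausibly contribute to poor accuracy. The following is a rough sanity check under stated assumptions: channel-last arrays as produced by `read_im`, synthetic data standing in for real images, and hypothetical helper names `channel_stats` and `normalize`.

```python
import numpy as np


def channel_stats(image):
    """Return per-channel (mean, std) for a channel-last (H, W, C) array."""
    image = np.asarray(image, dtype=np.float64)
    flat = image.reshape(-1, image.shape[-1])
    return flat.mean(axis=0), flat.std(axis=0)


def normalize(image, mean, std):
    """Per-channel standardization; mean and std have shape (C,)."""
    return (np.asarray(image, dtype=np.float64) - mean) / std


# Synthetic stand-in for one 8-channel image stored with values in [0, 255].
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 8)).astype(np.float64)

# Suppose the normalization constants were computed on images scaled to [0, 1].
mean01, std01 = channel_stats(raw / 255.0)

# Matching scale: normalized channels have mean close to 0 and std close to 1.
matched_mean, matched_std = channel_stats(normalize(raw / 255.0, mean01, std01))

# Mismatched scale: the same constants applied to unscaled data give
# per-channel means far from 0.
mismatched_mean, _ = channel_stats(normalize(raw, mean01, std01))
```

Running `channel_stats` on a handful of real samples after `self.transform` (converted back to channel-last if the transform reorders axes) and comparing against 0 and 1 gives a quick indication of whether the scale matches.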
