Skip to content

Commit

Permalink
Code refactored, new training techniques added
Browse files Browse the repository at this point in the history
- Mixup added
- API updated
- FocalLoss added for adversarial training
- More efficient pre-processing algorithm added
- unnecessary methods removed
  • Loading branch information
Naghipourfar committed Jun 9, 2023
1 parent f63f497 commit 6161fec
Show file tree
Hide file tree
Showing 6 changed files with 1,058 additions and 1,144 deletions.
14 changes: 7 additions & 7 deletions cpa/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def __init__(self, adata: AnnData, model: CPA,

self.unique_perts = list(model.drug_encoder.keys())
self.unique_covars = {}
for covar in model.cat_covars_encoders.keys():
self.unique_covars[covar] = list(model.cat_covars_encoders[covar].keys())
for covar in model.covars_encoder.keys():
self.unique_covars[covar] = list(model.covars_encoder[covar].keys())

self.num_drugs = len(model.drug_encoder)

self.perts_dict = model.drug_encoder.copy()
self.covars_dict = model.cat_covars_encoders.copy()
self.covars_dict = model.covars_encoder.copy()

self.emb_covars = None
self.emb_perts = None
Expand Down Expand Up @@ -247,10 +247,10 @@ def latent_dose_response(self, perturbations=None, dose=None,
d = self.perts_dict[drug]
this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
if self.model.module.doser_type == 'mlp':
response = (self.model.module.drug_network.dosers[d](this_drug).sigmoid() * this_drug.gt(
response = (self.model.module.pert_network.dosers[d](this_drug).sigmoid() * this_drug.gt(
0)).cpu().clone().detach().numpy().reshape(-1)
else:
response = self.model.module.drug_network.dosers.one_drug(this_drug.view(-1),
response = self.model.module.pert_network.dosers.one_drug(this_drug.view(-1),
d).cpu().clone().detach().numpy().reshape(-1)

df_drug = pd.DataFrame(list(zip([drug] * n_points, dose, list(response))),
Expand Down Expand Up @@ -295,10 +295,10 @@ def latent_dose_response2D(self, perturbations, dose=None,
d = self.perts_dict[drug]
this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
if self.model.module.doser_type == 'mlp':
response[drug] = (self.model.module.drug_network.dosers[d](this_drug).sigmoid() * this_drug.gt(
response[drug] = (self.model.module.pert_network.dosers[d](this_drug).sigmoid() * this_drug.gt(
0)).cpu().clone().detach().numpy().reshape(-1)
else:
response[drug] = self.model.module.drug_network.dosers.one_drug(this_drug.view(-1),
response[drug] = self.model.module.pert_network.dosers.one_drug(this_drug.view(-1),
d).cpu().clone().detach().numpy().reshape(
-1)

Expand Down
8 changes: 6 additions & 2 deletions cpa/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ def __init__(
self.test_idx = test_indices

def setup(self, stage: Optional[str] = None):
gpus, self.device = parse_use_gpu_arg(self.use_gpu, return_device=True)
accelerator, _, self.device = parse_use_gpu_arg(
self.use_gpu, return_device=True
)
self.pin_memory = (
True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
True
if (settings.dl_pin_memory_gpu_training and accelerator == "gpu")
else False
)

def train_dataloader(self):
Expand Down
Loading

0 comments on commit 6161fec

Please sign in to comment.