Skip to content

Commit

Permalink
additional logging
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisA92 committed Feb 25, 2025
1 parent a6b108d commit be3be61
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 15 deletions.
9 changes: 8 additions & 1 deletion src/integrator/callbacks/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer, pl_module):
"thresholded_mean",
"dials_I_prf_value",
),
"corrcoef": torch.corrcoef(
"corrcoef qI": torch.corrcoef(
torch.vstack(
[
self.train_predictions["qI"].mean.flatten(),
Expand All @@ -186,6 +186,13 @@ def on_train_epoch_end(self, trainer, pl_module):
)
)[0, 1],
"max_qI": torch.max(self.train_predictions["qI"].mean.flatten()),
"max_bg": torch.max(
self.train_predictions["dials_I_prf_value"].flatten()
),
"mean_qI": torch.mean(self.train_predictions["qI"].mean.flatten()),
"mean_bg": torch.mean(
self.train_predictions["dials_I_prf_value"].flatten()
),
}

# Only create and log comparison grid on specified epochs
Expand Down
15 changes: 11 additions & 4 deletions src/integrator/configs/dev_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@ integrator:

components:
encoder:
name: dev_encoder
name: 3d_cnn
params:
depth: 10
dmodel: 64
feature_dim: 7
Z: 3
H: 21
W: 21
conv_channels: 64
use_norm: true
#name: dev_encoder
#params:
#depth: 10
#dmodel: 64
#feature_dim: 7
profile:
name: dirichlet
params:
Expand Down
10 changes: 8 additions & 2 deletions src/integrator/model/integrators/default_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def training_step(self, batch, batch_idx):
outputs = self(shoebox, dials, masks, metadata, counts)

# neg_ll, kl = self.loss_fn(
loss, neg_ll, kl, recon_loss = self.loss_fn(
loss, neg_ll, kl, recon_loss, kl_bg, kl_I, kl_p = self.loss_fn(
outputs["rates"],
outputs["counts"],
outputs["qp"],
Expand All @@ -246,6 +246,9 @@ def training_step(self, batch, batch_idx):
self.log("train_nll", neg_ll.mean())
self.log("train_kl", kl.mean())
self.log("train_recon", recon_loss)
self.log("kl_bg", kl_bg)
self.log("kl_I", kl_I)
self.log("kl_p", kl_p)

return loss.mean()

Expand Down Expand Up @@ -275,7 +278,7 @@ def validation_step(self, batch, batch_idx):
outputs = self(shoebox, dials, masks, metadata, counts)

# Calculate validation metrics
loss, neg_ll, kl, recon_loss = self.loss_fn(
loss, neg_ll, kl, recon_loss, kl_bg, kl_I, kl_p = self.loss_fn(
outputs["rates"],
outputs["counts"],
outputs["qp"],
Expand All @@ -289,6 +292,9 @@ def validation_step(self, batch, batch_idx):
self.log("val_nll", neg_ll.mean())
self.log("val_kl", kl.mean())
self.log("val_recon", recon_loss)
self.log("val_kl_bg", kl_bg)
self.log("val_kl_I", kl_I)
self.log("val_kl_p", kl_p)

# Return the complete outputs dictionary
return outputs
Expand Down
3 changes: 3 additions & 0 deletions src/integrator/model/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,7 @@ def forward(self, rate, counts, q_p, q_I, q_bg, dead_pixel_mask):
neg_ll_batch.mean(),
kl_terms.mean(),
recon_loss_batch.mean(),
kl_bg.mean(),
kl_I.mean(),
kl_p.mean() if self.p_pairing == "dirichlet_dirichlet" else 0.0,
)
9 changes: 4 additions & 5 deletions src/integrator/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@

REGISTRY = {
"encoder": {
"encoder1": CNNResNet,
"fc_encoder": FcEncoder,
"encoder2": CNNResNet2,
"encoder1": CNNResNet, # done
"fc_encoder": FcEncoder, # for metadata
"dev_encoder": DevEncoder,
"fc_resnet": FcResNet,
"3d_cnn": CNN_3d,
"fc_resnet": FcResNet, # done
"3d_cnn": CNN_3d, # done
},
"decoder": {
"decoder1": Decoder,
Expand Down
16 changes: 15 additions & 1 deletion test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,23 @@
refl_tbl_subset = refl_tbl.select(flex.bool(sel))

# %%

encoder_name = config["components"]["encoder"]["name"]
I_pairing_name = config["components"]["loss"]["params"]["I_pairing"]
bg_pairing_name = config["components"]["loss"]["params"]["bg_pairing"]
p_pairing_name = config["components"]["loss"]["params"]["p_pairing"]


logger = WandbLogger(
project="integrator",
name="test-simpson-reg-local-3",
name="Encoder_"
+ encoder_name
+ "_I_"
+ I_pairing_name
+ "_Bg_"
+ bg_pairing_name
+ "_P_"
+ p_pairing_name,
save_dir="lightning_logs",
)

Expand Down
16 changes: 14 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,25 @@ def analysis(prediction_path, dials_env, phenix_env, pdb, expt_file):
],
)

encoder_name = config["components"]["encoder"]["name"]
I_pairing_name = config["components"]["loss"]["params"]["I_pairing"]
bg_pairing_name = config["components"]["loss"]["params"]["bg_pairing"]
p_pairing_name = config["components"]["loss"]["params"]["p_pairing"]

logger = WandbLogger(
project="integrator",
name="test-run",
name="Encoder_"
+ encoder_name
+ "_I_"
+ I_pairing_name
+ "_Bg_"
+ bg_pairing_name
+ "_P_"
+ p_pairing_name,
save_dir="lightning_logs",
)

plotter = IntensityPlotter()
plotter = IntensityPlotter(num_profiles=10)

## create checkpoint callback
checkpoint_callback = ModelCheckpoint(
Expand Down

0 comments on commit be3be61

Please sign in to comment.