Skip to content

Commit fd83954

Browse files
authored
Merge pull request #218 from RaulPPelaez/use_update
Use update instead of |= so Trainer is compatible with Python 3.8
2 parents e964a72 + 6e8b7ae commit fd83954

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torchmdnet/module.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ def on_validation_epoch_end(self):
212212
"epoch": float(self.current_epoch),
213213
"lr": self.trainer.optimizers[0].param_groups[0]["lr"],
214214
}
215-
result_dict |= self._get_mean_loss_dict_for_type("total")
216-
result_dict |= self._get_mean_loss_dict_for_type("y")
217-
result_dict |= self._get_mean_loss_dict_for_type("neg_dy")
215+
result_dict.update(self._get_mean_loss_dict_for_type("total"))
216+
result_dict.update(self._get_mean_loss_dict_for_type("y"))
217+
result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
218218
# For retro compatibility with previous versions of TorchMD-Net we report some losses twice
219219
result_dict["val_loss"] = result_dict["val_total_mse_loss"]
220220
result_dict["train_loss"] = result_dict["train_total_mse_loss"]
@@ -228,9 +228,9 @@ def on_test_epoch_end(self):
228228
# Log all test losses
229229
if not self.trainer.sanity_checking:
230230
result_dict = {}
231-
result_dict |= self._get_mean_loss_dict_for_type("total")
232-
result_dict |= self._get_mean_loss_dict_for_type("y")
233-
result_dict |= self._get_mean_loss_dict_for_type("neg_dy")
231+
result_dict.update(self._get_mean_loss_dict_for_type("total"))
232+
result_dict.update(self._get_mean_loss_dict_for_type("y"))
233+
result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
234234
# Get only test entries
235235
result_dict = {k: v for k, v in result_dict.items() if k.startswith("test")}
236236
self.log_dict(result_dict, sync_dist=True)

0 commit comments

Comments
 (0)