@@ -212,9 +212,9 @@ def on_validation_epoch_end(self):
212
212
"epoch" : float (self .current_epoch ),
213
213
"lr" : self .trainer .optimizers [0 ].param_groups [0 ]["lr" ],
214
214
}
215
- result_dict |= self ._get_mean_loss_dict_for_type ("total" )
216
- result_dict |= self ._get_mean_loss_dict_for_type ("y" )
217
- result_dict |= self ._get_mean_loss_dict_for_type ("neg_dy" )
215
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("total" ) )
216
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("y" ) )
217
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("neg_dy" ) )
218
218
# For retro compatibility with previous versions of TorchMD-Net we report some losses twice
219
219
result_dict ["val_loss" ] = result_dict ["val_total_mse_loss" ]
220
220
result_dict ["train_loss" ] = result_dict ["train_total_mse_loss" ]
@@ -228,9 +228,9 @@ def on_test_epoch_end(self):
228
228
# Log all test losses
229
229
if not self .trainer .sanity_checking :
230
230
result_dict = {}
231
- result_dict |= self ._get_mean_loss_dict_for_type ("total" )
232
- result_dict |= self ._get_mean_loss_dict_for_type ("y" )
233
- result_dict |= self ._get_mean_loss_dict_for_type ("neg_dy" )
231
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("total" ) )
232
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("y" ) )
233
+ result_dict . update ( self ._get_mean_loss_dict_for_type ("neg_dy" ) )
234
234
# Get only test entries
235
235
result_dict = {k : v for k , v in result_dict .items () if k .startswith ("test" )}
236
236
self .log_dict (result_dict , sync_dist = True )
0 commit comments