diff --git a/matsciml/lightning/data_utils.py b/matsciml/lightning/data_utils.py index ec045c84..0e82c1ff 100644 --- a/matsciml/lightning/data_utils.py +++ b/matsciml/lightning/data_utils.py @@ -291,7 +291,7 @@ def predict_dataloader(self): target, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, - collate_fn=self.dataset.collate_fn, + collate_fn=target.collate_fn, persistent_workers=self.persistent_workers, )