diff --git a/examples/main_example/dataset.py b/examples/main_example/dataset.py index 7fa7ba3f..afc31da1 100644 --- a/examples/main_example/dataset.py +++ b/examples/main_example/dataset.py @@ -33,8 +33,8 @@ def get_domain_data( class DomainDataModule(LightningDataModule): def __init__( self, - val_dataset: torch.Tensor, train_dataset: torch.Tensor, + val_dataset: torch.Tensor, batch_size: int, ) -> None: super().__init__() @@ -109,8 +109,8 @@ def make_datasets( class GWDataModule(LightningDataModule): def __init__( self, - val_datasets: dict[frozenset[str], DomainDataset], train_datasets: dict[frozenset[str], DomainDataset], + val_datasets: dict[frozenset[str], DomainDataset], batch_size: int, ) -> None: super().__init__()