diff --git a/examples/main_example/dataset.py b/examples/main_example/dataset.py
index 7fa7ba3f..afc31da1 100644
--- a/examples/main_example/dataset.py
+++ b/examples/main_example/dataset.py
@@ -33,8 +33,8 @@ def get_domain_data(
 class DomainDataModule(LightningDataModule):
     def __init__(
         self,
-        val_dataset: torch.Tensor,
         train_dataset: torch.Tensor,
+        val_dataset: torch.Tensor,
         batch_size: int,
     ) -> None:
         super().__init__()
@@ -109,8 +109,8 @@ def make_datasets(
 class GWDataModule(LightningDataModule):
     def __init__(
         self,
-        val_datasets: dict[frozenset[str], DomainDataset],
         train_datasets: dict[frozenset[str], DomainDataset],
+        val_datasets: dict[frozenset[str], DomainDataset],
         batch_size: int,
     ) -> None:
         super().__init__()