@@ -47,7 +47,7 @@ def create_queue() -> Queue:
4747 return q
4848
4949
50- def train_model (queue : Queue , max_epochs : int , ckpt_path : Path ) -> Trainer :
50+ def train_model (queue : Queue , max_epochs : int , ckpt_path : Path ) -> None :
5151 dataloader = DataLoader (QueueDataset (queue ), num_workers = 1 , batch_size = None , persistent_workers = True )
5252 trainer = Trainer (
5353 max_epochs = max_epochs ,
@@ -61,21 +61,17 @@ def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
6161 else :
6262 trainer .fit (BoringModel (), dataloader )
6363 trainer .save_checkpoint (str (ckpt_path ))
64- return trainer
6564
6665
6766def test_resume_training_with (tmp_path ):
6867 """Test resuming training from checkpoint file using a IterableDataset."""
6968 queue = create_queue ()
7069 max_epoch = 2
7170 ckpt_path = tmp_path / "model.ckpt"
72- trainer = train_model (queue , max_epoch , ckpt_path )
73- assert trainer is not None
71+ train_model (queue , max_epoch , ckpt_path )
7472
7573 assert os .path .exists (ckpt_path ), f"Checkpoint file '{ ckpt_path } ' wasn't created"
76-
7774 ckpt_size = os .path .getsize (ckpt_path )
7875 assert ckpt_size > 0 , f"Checkpoint file is empty (size: { ckpt_size } bytes)"
7976
80- trainer = train_model (queue , max_epoch + 2 , ckpt_path )
81- assert trainer is not None
77+ train_model (queue , max_epoch + 2 , ckpt_path )
0 commit comments