@@ -796,12 +796,13 @@ def validation_step(self, batch, batch_idx):
796796 raise RuntimeError ("Trouble!" )
797797
798798 model = TroubledModel ()
799- epoch_length = 64
799+ epoch_length = 2
800800 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
801801 trainer = Trainer (
802802 default_root_dir = tmp_path ,
803803 callbacks = [checkpoint_callback ],
804804 max_epochs = 5 ,
805+ limit_train_batches = epoch_length ,
805806 logger = False ,
806807 enable_progress_bar = False ,
807808 )
@@ -864,12 +865,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864865 raise RuntimeError ("Trouble!" )
865866
866867 model = BoringModel ()
867- epoch_length = 64
868+ epoch_length = 2
868869 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
869870 trainer = Trainer (
870871 default_root_dir = tmp_path ,
871872 callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872873 max_epochs = 5 ,
874+ limit_train_batches = epoch_length ,
873875 logger = False ,
874876 enable_progress_bar = False ,
875877 )
@@ -887,12 +889,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887889 raise RuntimeError ("Trouble!" )
888890
889891 model = BoringModel ()
890- epoch_length = 64
892+ epoch_length = 2
891893 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
892894 trainer = Trainer (
893895 default_root_dir = tmp_path ,
894896 callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895897 max_epochs = 5 ,
898+ limit_train_batches = epoch_length ,
896899 logger = False ,
897900 enable_progress_bar = False ,
898901 )
@@ -956,12 +959,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956959 raise RuntimeError ("Trouble!" )
957960
958961 model = BoringModel ()
959- epoch_length = 64
962+ epoch_length = 2
960963 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
961964 trainer = Trainer (
962965 default_root_dir = tmp_path ,
963966 callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964967 max_epochs = 5 ,
968+ limit_train_batches = epoch_length ,
965969 logger = False ,
966970 enable_progress_bar = False ,
967971 )
@@ -979,12 +983,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979983 raise RuntimeError ("Trouble!" )
980984
981985 model = BoringModel ()
982- epoch_length = 64
986+ epoch_length = 2
983987 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
984988 trainer = Trainer (
985989 default_root_dir = tmp_path ,
986990 callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987991 max_epochs = 5 ,
992+ limit_train_batches = epoch_length ,
988993 logger = False ,
989994 enable_progress_bar = False ,
990995 )
@@ -1002,12 +1007,13 @@ def on_validation_start(self, trainer, pl_module):
10021007 raise RuntimeError ("Trouble!" )
10031008
10041009 model = BoringModel ()
1005- epoch_length = 64
1010+ epoch_length = 2
10061011 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
10071012 trainer = Trainer (
10081013 default_root_dir = tmp_path ,
10091014 callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
10101015 max_epochs = 5 ,
1016+ limit_train_batches = epoch_length ,
10111017 logger = False ,
10121018 enable_progress_bar = False ,
10131019 )
@@ -1025,12 +1031,13 @@ def on_validation_end(self, trainer, pl_module):
10251031 raise RuntimeError ("Trouble!" )
10261032
10271033 model = BoringModel ()
1028- epoch_length = 64
1034+ epoch_length = 2
10291035 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
10301036 trainer = Trainer (
10311037 default_root_dir = tmp_path ,
10321038 callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
10331039 max_epochs = 5 ,
1040+ limit_train_batches = epoch_length ,
10341041 logger = False ,
10351042 enable_progress_bar = False ,
10361043 )
0 commit comments