@@ -28,15 +28,15 @@ def setUp(self):
2828 super ().setUp ()
2929 self .preservation_context = preservation_policy_lib .PreservationContext ()
3030
31- def get_checkpoints (self , n : int ):
31+ def get_checkpoints (self , steps : Sequence [ int ] = ( 0 , 1 , 2 , 3 , 4 ) ):
3232 checkpoints = []
3333 time_increment = datetime .timedelta (seconds = 1 )
3434 start_time = datetime .datetime .now ()
35- for i in range ( n ):
35+ for i , step in enumerate ( steps ):
3636 current_time = start_time + i * time_increment
3737 checkpoints .append (
3838 checkpoint_info .CheckpointInfo (
39- step = i ,
39+ step = step ,
4040 time = current_time ,
4141 metrics = None
4242 )
@@ -79,7 +79,8 @@ def test_latest_n_policy(self, n, expected_preserved_steps):
7979
8080 self .assertSequenceEqual (
8181 expected_preserved_steps ,
82- self .get_preserved_checkpoints (self .get_checkpoints (5 ), policy ),
82+ self .get_preserved_checkpoints (
83+ self .get_checkpoints (), policy ),
8384 )
8485
8586 @parameterized .parameters (
@@ -103,35 +104,45 @@ def test_every_n_seconds_policy(
103104
104105 self .assertSequenceEqual (
105106 expected_preserved_steps ,
106- self .get_preserved_checkpoints (self .get_checkpoints (5 ), policy ),
107+ self .get_preserved_checkpoints (self .get_checkpoints (), policy ),
107108 )
108109
109110 @parameterized .parameters (
110111 dict (
111112 interval_steps = 1 ,
113+ steps = [0 , 1 , 2 , 3 , 4 ],
112114 expected_preserved_steps = [0 , 1 , 2 , 3 , 4 ],
113115 ),
114116 dict (
115117 interval_steps = 3 ,
118+ steps = [0 , 1 , 2 , 3 , 4 ],
116119 expected_preserved_steps = [0 , 3 ],
117120 ),
118121 dict (
119122 interval_steps = 6 ,
123+ steps = [0 , 1 , 2 , 3 , 4 ],
120124 expected_preserved_steps = [0 ],
121125 ),
126+ dict (
127+ interval_steps = 3 ,
128+ steps = [0 , 1 , 2 , 4 , 5 , 8 , 9 , 13 , 14 , 25 ],
129+ expected_preserved_steps = [0 , 4 , 8 , 13 , 25 ],
130+ ),
122131 )
123- def test_every_n_steps_policy (self , interval_steps , expected_preserved_steps ):
132+ def test_every_n_steps_policy (
133+ self , interval_steps , steps , expected_preserved_steps
134+ ):
124135 policy = preservation_policy_lib .EveryNSteps (interval_steps = interval_steps )
125136
126137 self .assertEqual (
127138 expected_preserved_steps ,
128- self .get_preserved_checkpoints (self .get_checkpoints (5 ), policy ),
139+ self .get_preserved_checkpoints (self .get_checkpoints (steps ), policy ),
129140 )
130141
131142 def test_every_zero_steps_policy_raises_error (self ):
132143 policy = preservation_policy_lib .EveryNSteps (interval_steps = 0 )
133144 with self .assertRaises (ValueError ):
134- self .get_preserved_checkpoints (self .get_checkpoints (5 ), policy )
145+ self .get_preserved_checkpoints (self .get_checkpoints (), policy )
135146
136147 @parameterized .parameters (
137148 dict (
@@ -152,7 +163,7 @@ def test_custom_steps_policy(self, steps, expected_preserved_steps):
152163
153164 self .assertEqual (
154165 expected_preserved_steps ,
155- self .get_preserved_checkpoints (self .get_checkpoints (5 ), policy ),
166+ self .get_preserved_checkpoints (self .get_checkpoints (), policy ),
156167 )
157168
158169 @parameterized .parameters (
@@ -190,7 +201,7 @@ def test_best_n_policy(
190201 n = n ,
191202 keep_checkpoints_without_metrics = keep_checkpoints_without_metrics ,
192203 )
193- checkpoints = self .get_checkpoints (5 )
204+ checkpoints = self .get_checkpoints ()
194205 for i , checkpoint in enumerate (checkpoints ):
195206 if loss [i ]:
196207 checkpoint .metrics = {'loss' : loss [i ]}
@@ -219,7 +230,8 @@ def test_joint_preservation_policy(self):
219230 ]
220231 )
221232 loss = [5 , None , 4 , None , 3 , None , 11 , None , 8 , None , 12 , None ]
222- checkpoints = self .get_checkpoints (12 )
233+ steps = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 9 , 10 , 11 ]
234+ checkpoints = self .get_checkpoints (steps )
223235 for i , checkpoint in enumerate (checkpoints ):
224236 if loss [i ]:
225237 checkpoint .metrics = {'loss' : loss [i ]}
0 commit comments