Skip to content

Commit cecd094

Browse files
abhishek002002 and Orbax Authors
authored and committed
Migrate Tunelab to Orbax PreservationPolicy for checkpoint management.
PiperOrigin-RevId: 806163667
1 parent a863a99 commit cecd094

File tree

4 files changed

+38
-16
lines changed

4 files changed

+38
-16
lines changed

checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def should_preserve(
145145

146146
@dataclasses.dataclass
147147
class EveryNSteps(PreservationPolicy):
148-
"""Ensures checkpoints are preserved at least every N steps."""
148+
"""Preserves checkpoints after at least N steps."""
149149

150150
interval_steps: int
151151

@@ -157,7 +157,17 @@ def should_preserve(
157157
) -> Sequence[bool]:
158158
if self.interval_steps == 0:
159159
raise ValueError("interval_steps must not be 0.")
160-
result = [ckpt.step % self.interval_steps == 0 for ckpt in checkpoints]
160+
result = []
161+
previous_step = None
162+
for i, ckpt in enumerate(checkpoints):
163+
if i == 0:
164+
result.append(True) # Always preserve the first checkpoint.
165+
previous_step = ckpt.step
166+
elif ckpt.step - previous_step >= self.interval_steps:
167+
result.append(True)
168+
previous_step = ckpt.step
169+
else:
170+
result.append(False)
161171
_log_preservation_decision(
162172
f"EveryNSteps (interval_steps={self.interval_steps})",
163173
checkpoints,

checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy_test.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -28,15 +28,15 @@ def setUp(self):
2828
super().setUp()
2929
self.preservation_context = preservation_policy_lib.PreservationContext()
3030

31-
def get_checkpoints(self, n: int):
31+
def get_checkpoints(self, steps: Sequence[int] = (0, 1, 2, 3, 4)):
3232
checkpoints = []
3333
time_increment = datetime.timedelta(seconds=1)
3434
start_time = datetime.datetime.now()
35-
for i in range(n):
35+
for i, step in enumerate(steps):
3636
current_time = start_time + i * time_increment
3737
checkpoints.append(
3838
checkpoint_info.CheckpointInfo(
39-
step=i,
39+
step=step,
4040
time=current_time,
4141
metrics=None
4242
)
@@ -79,7 +79,8 @@ def test_latest_n_policy(self, n, expected_preserved_steps):
7979

8080
self.assertSequenceEqual(
8181
expected_preserved_steps,
82-
self.get_preserved_checkpoints(self.get_checkpoints(5), policy),
82+
self.get_preserved_checkpoints(
83+
self.get_checkpoints(), policy),
8384
)
8485

8586
@parameterized.parameters(
@@ -103,35 +104,45 @@ def test_every_n_seconds_policy(
103104

104105
self.assertSequenceEqual(
105106
expected_preserved_steps,
106-
self.get_preserved_checkpoints(self.get_checkpoints(5), policy),
107+
self.get_preserved_checkpoints(self.get_checkpoints(), policy),
107108
)
108109

109110
@parameterized.parameters(
110111
dict(
111112
interval_steps=1,
113+
steps=[0, 1, 2, 3, 4],
112114
expected_preserved_steps=[0, 1, 2, 3, 4],
113115
),
114116
dict(
115117
interval_steps=3,
118+
steps=[0, 1, 2, 3, 4],
116119
expected_preserved_steps=[0, 3],
117120
),
118121
dict(
119122
interval_steps=6,
123+
steps=[0, 1, 2, 3, 4],
120124
expected_preserved_steps=[0],
121125
),
126+
dict(
127+
interval_steps=3,
128+
steps=[0, 1, 2, 4, 5, 8, 9, 13, 14, 25],
129+
expected_preserved_steps=[0, 4, 8, 13, 25],
130+
),
122131
)
123-
def test_every_n_steps_policy(self, interval_steps, expected_preserved_steps):
132+
def test_every_n_steps_policy(
133+
self, interval_steps, steps, expected_preserved_steps
134+
):
124135
policy = preservation_policy_lib.EveryNSteps(interval_steps=interval_steps)
125136

126137
self.assertEqual(
127138
expected_preserved_steps,
128-
self.get_preserved_checkpoints(self.get_checkpoints(5), policy),
139+
self.get_preserved_checkpoints(self.get_checkpoints(steps), policy),
129140
)
130141

131142
def test_every_zero_steps_policy_raises_error(self):
132143
policy = preservation_policy_lib.EveryNSteps(interval_steps=0)
133144
with self.assertRaises(ValueError):
134-
self.get_preserved_checkpoints(self.get_checkpoints(5), policy)
145+
self.get_preserved_checkpoints(self.get_checkpoints(), policy)
135146

136147
@parameterized.parameters(
137148
dict(
@@ -152,7 +163,7 @@ def test_custom_steps_policy(self, steps, expected_preserved_steps):
152163

153164
self.assertEqual(
154165
expected_preserved_steps,
155-
self.get_preserved_checkpoints(self.get_checkpoints(5), policy),
166+
self.get_preserved_checkpoints(self.get_checkpoints(), policy),
156167
)
157168

158169
@parameterized.parameters(
@@ -190,7 +201,7 @@ def test_best_n_policy(
190201
n=n,
191202
keep_checkpoints_without_metrics=keep_checkpoints_without_metrics,
192203
)
193-
checkpoints = self.get_checkpoints(5)
204+
checkpoints = self.get_checkpoints()
194205
for i, checkpoint in enumerate(checkpoints):
195206
if loss[i]:
196207
checkpoint.metrics = {'loss': loss[i]}
@@ -219,7 +230,8 @@ def test_joint_preservation_policy(self):
219230
]
220231
)
221232
loss = [5, None, 4, None, 3, None, 11, None, 8, None, 12, None]
222-
checkpoints = self.get_checkpoints(12)
233+
steps = [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11]
234+
checkpoints = self.get_checkpoints(steps)
223235
for i, checkpoint in enumerate(checkpoints):
224236
if loss[i]:
225237
checkpoint.metrics = {'loss': loss[i]}

checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,8 +185,8 @@ class PersistentCheckpointOptions:
185185
may be considered for deletion when there are more than `max_to_keep`
186186
checkpoints present.
187187
keep_period:
188-
If set, any existing checkpoints matching checkpoint_step % keep_period == 0
189-
will not be deleted.
188+
If set, any existing checkpoints after every at least keep_period steps will
189+
be preserved.
190190
should_save_fn:
191191
Predicate callable to check if given step can be saved. This callable
192192
accepts step number and optional latest step number as param and returns

checkpoint/orbax/checkpoint/experimental/emergency/test_utils/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -810,7 +810,7 @@ def test_all_steps(
810810
@parameterized.parameters(
811811
(2, 2, 2, [0, 2, 4, 6, 8, 9]),
812812
(3, 2, 4, [0, 4, 7, 8, 9]),
813-
(2, 2, 5, [0, 8, 9]),
813+
(2, 2, 5, [0, 6, 8, 9]),
814814
(2, 6, 3, [0, 6, 8, 9]),
815815
)
816816
def test_all_steps_with_keep_interval(

0 commit comments

Comments
 (0)