Skip to content

Commit 9965aba

Browse files
abhishek002002Orbax Authors
authored andcommitted
Update EveryNSteps preservation policy.
PiperOrigin-RevId: 805206191
1 parent ba906d4 commit 9965aba

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,17 @@ def should_preserve(
157157
) -> Sequence[bool]:
158158
if self.interval_steps == 0:
159159
raise ValueError("interval_steps must not be 0.")
160-
result = [ckpt.step % self.interval_steps == 0 for ckpt in checkpoints]
160+
result = []
161+
previous_step = None
162+
for i, ckpt in enumerate(checkpoints):
163+
if i == 0:
164+
result.append(True) # Always preserve the first checkpoint.
165+
previous_step = ckpt.step
166+
elif ckpt.step - previous_step >= self.interval_steps:
167+
result.append(True)
168+
previous_step = ckpt.step
169+
else:
170+
result.append(False)
161171
_log_preservation_decision(
162172
f"EveryNSteps (interval_steps={self.interval_steps})",
163173
checkpoints,

0 commit comments

Comments
 (0)