diff --git a/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py b/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py index 32cbfd1d0..c1d418b5c 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py +++ b/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py @@ -157,7 +157,17 @@ def should_preserve( ) -> Sequence[bool]: if self.interval_steps == 0: raise ValueError("interval_steps must not be 0.") - result = [ckpt.step % self.interval_steps == 0 for ckpt in checkpoints] + result = [] + previous_step = None + for i, ckpt in enumerate(checkpoints): + if i == 0: + result.append(True) # Always preserve the first checkpoint. + previous_step = ckpt.step + elif ckpt.step - previous_step >= self.interval_steps: + result.append(True) + previous_step = ckpt.step + else: + result.append(False) _log_preservation_decision( f"EveryNSteps (interval_steps={self.interval_steps})", checkpoints,