From 8a10c2563b1d3732d69f708ef36375a6f86cf025 Mon Sep 17 00:00:00 2001 From: Abhishek Agrawal Date: Tue, 9 Sep 2025 06:17:35 -0700 Subject: [PATCH] Migrate Tunelab to Orbax `PreservationPolicy` for checkpoint management. PiperOrigin-RevId: 804878789 --- .../_src/checkpoint_managers/preservation_policy.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py b/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py index 32cbfd1d0..c1d418b5c 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py +++ b/checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py @@ -157,7 +157,17 @@ def should_preserve( ) -> Sequence[bool]: if self.interval_steps == 0: raise ValueError("interval_steps must not be 0.") - result = [ckpt.step % self.interval_steps == 0 for ckpt in checkpoints] + result = [] + previous_step = None + for i, ckpt in enumerate(checkpoints): + if i == 0: + result.append(True) # Always preserve the first checkpoint. + previous_step = ckpt.step + elif ckpt.step - previous_step >= self.interval_steps: + result.append(True) + previous_step = ckpt.step + else: + result.append(False) _log_preservation_decision( f"EveryNSteps (interval_steps={self.interval_steps})", checkpoints,