diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 07b44154..720ff17b 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -37,6 +37,7 @@ from orbax.checkpoint._src.path import atomicity_types + BarrierSyncFn = multihost.BarrierSyncFn _DIRECTORY_CREATION_SIGNALS = [ synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION @@ -49,7 +50,8 @@ def _on_commit_callback( ): """Finalize atomic save and record checkpoint save metrics.""" atomicity.on_commit_callback( - tmpdir, checkpoint_start_time=checkpoint_start_time + tmpdir, + checkpoint_start_time=checkpoint_start_time, ) total_duration_secs = time.time() - checkpoint_start_time jax.monitoring.record_event_duration_secs( @@ -441,7 +443,10 @@ def _callback() -> None: threading.current_thread().name, tmpdir.get(), ) - _on_commit_callback(tmpdir, checkpoint_start_time) + _on_commit_callback( + tmpdir, + checkpoint_start_time, + ) self._async_manager.start_async_commit( directory, diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py index 30cbac48..df8d0573 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -552,7 +552,8 @@ def on_commit_callback( Args: tmp_dir: A temporary checkpoint directory, where the checkpoint data is currently saved. - checkpoint_start_time: The time at which checkpoint saving began. + checkpoint_start_time: The time at which checkpoint saving began. # BEGIN + tree_verity_options: Options to configure checkpoint signing and integrity """ tmp_dir.finalize() step_lib.record_saved_duration(checkpoint_start_time) diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 28075ae6..977235ec 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -56,6 +56,7 @@ from typing_extensions import Self # for Python version < 3.11 + PyTree = Any CheckpointDirs = Tuple[str, str] SaveParams = Mapping[str, Any] diff --git a/checkpoint/orbax/checkpoint/options.py b/checkpoint/orbax/checkpoint/options.py index fd3ef1ec..db34826e 100644 --- a/checkpoint/orbax/checkpoint/options.py +++ b/checkpoint/orbax/checkpoint/options.py @@ -70,5 +70,3 @@ class FileOptions: path_permission_mode: Optional[int] = None - -