diff --git a/checkpoint/orbax/checkpoint/_src/path/async_path.py b/checkpoint/orbax/checkpoint/_src/path/async_path.py index 2a33003e8..be3e8296d 100644 --- a/checkpoint/orbax/checkpoint/_src/path/async_path.py +++ b/checkpoint/orbax/checkpoint/_src/path/async_path.py @@ -91,6 +91,10 @@ async def is_dir(path: epath.Path): return await asyncio.to_thread(path.is_dir) +async def is_file(path: epath.Path): + return await asyncio.to_thread(path.is_file) + + async def is_link(path: epath.Path): return await asyncio.to_thread(os.path.islink, path) @@ -101,3 +105,7 @@ async def iterdir(path: epath.Path): async def glob(path: epath.Path, pattern: str) -> Iterator[epath.Path]: return await asyncio.to_thread(path.glob, pattern) + + +async def unlink(path: epath.Path, missing_ok: bool = False): + return await asyncio.to_thread(path.unlink, missing_ok=missing_ok) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py index 06850cbfb..08324154d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py @@ -37,7 +37,7 @@ from orbax.checkpoint.experimental.v1._src.context.context import ( Context, ) -from orbax.checkpoint.experimental.v1._src.layout.format_utils import ( +from orbax.checkpoint.experimental.v1._src.layout.orbax_layout import ( is_orbax_checkpoint, ) from orbax.checkpoint.experimental.v1._src.loading.loading import ( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py index a3afcee5d..79fc58eff 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +import itertools from typing import Any, Awaitable from absl import logging @@ -29,7 +30,7 @@ from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import -from orbax.checkpoint.experimental.v1._src.path import format_utils +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.path import types as path_types @@ -39,6 +40,15 @@ ORBAX_CHECKPOINT_INDICATOR_FILE = 'orbax.checkpoint' +def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]: + return list( + itertools.islice( + (subdir.name for subdir in directory.iterdir() if subdir.is_dir()), + limit, + ) + ) + + _V0_ERROR_MESSAGE = ( 'If your checkpoint was saved with the Orbax V0 API, please follow the' ' instructions at' @@ -57,7 +67,9 @@ async def _create_orbax_identifier_file( """Creates a file called `orbax.checkpoint` for easy identification.""" directory = await directory.await_creation() if multihost.is_primary_host(primary_host): - await async_path.touch(directory / 'orbax.checkpoint', exist_ok=True) + await async_path.touch( + directory / ORBAX_CHECKPOINT_INDICATOR_FILE, exist_ok=True + ) class CompositeHandler: @@ -165,7 +177,7 @@ async def load( abstract_checkpointables = { name: None for name in handlers_for_load.keys() - if name not in format_utils.RESERVED_CHECKPOINTABLE_KEYS + if name not in checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS and name in existing_checkpointable_names } if any( @@ -265,8 +277,8 @@ def _get_saved_handler_typestrs( return saved_metadata.item_handlers # found step level metadata. raise ValueError( f'Path at {directory} contains subdirectories:' - f' {format_utils.subdirs(directory)}, which are expected to match the' - ' keys given by the _CHECKPOINT_METADATA file:' + f' {_subdirs(directory)}, which are expected to' + ' match the keys given by the _CHECKPOINT_METADATA file:' f' {saved_metadata.item_handlers}. If you intended to load a pytree' ' checkpoint from the given path, then please consider using' ' `loading.load_pytree(..., checkpointable_name=None)` instead.' @@ -293,8 +305,8 @@ def _get_saved_handler_typestrs( if isinstance(saved_metadata.item_handlers, dict): raise ValueError( f'Path at {directory} contains subdirectories:' - f' {format_utils.subdirs(directory)}, which are expected to match' - ' the keys given by the _CHECKPOINT_METADATA file:' + f' {_subdirs(directory)}, which are expected to' + ' match the keys given by the _CHECKPOINT_METADATA file:' f' {saved_metadata.item_handlers}. If you intended to load a pytree' ' checkpoint from the given path, then please consider using' ' `loading.load_pytree(..., checkpointable_name=None)` instead.' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py index 4d2ff05c6..cf1953b28 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py @@ -21,6 +21,17 @@ Path = types.Path +### Constants shared by all layouts. ### + +PYTREE_CHECKPOINTABLE_KEY = "pytree" + +METRICS_CHECKPOINTABLE_KEY = "metrics" + +RESERVED_CHECKPOINTABLE_KEYS = frozenset({ + METRICS_CHECKPOINTABLE_KEY, +}) + + class InvalidLayoutError(ValueError): """Raised when the checkpoint layout is invalid.""" @@ -51,7 +62,7 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]: """ ... - def validate(self) -> None: + async def validate(self) -> None: """Validates the path, determining if it conforms to this instance. Returns: @@ -62,7 +73,7 @@ def validate(self) -> None: """ ... - def validate_pytree(self, checkpointable_name: str | None) -> None: + async def validate_pytree(self, checkpointable_name: str | None) -> None: """Validates the path as a PyTree checkpoint. Args: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/format_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/format_utils.py deleted file mode 100644 index b8cc8d348..000000000 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/format_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2025 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for validating checkpoint formats.""" - -from etils import epath -from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout -from orbax.checkpoint.experimental.v1._src.layout import orbax_layout -from orbax.checkpoint.experimental.v1._src.path import types as path_types - -InvalidLayoutError = checkpoint_layout.InvalidLayoutError - - -def is_orbax_checkpoint(path: path_types.PathLike) -> bool: - """Determines if the given path is an Orbax checkpoint. - - Args: - path: The path to the checkpoint directory. - - Returns: - True if the path is an Orbax checkpoint, False otherwise. - """ - path = epath.Path(path) - try: - orbax_layout.OrbaxLayout(path).validate() - return True - except InvalidLayoutError: - return False diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 68cba9e8b..5a68e0855 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -16,6 +16,7 @@ from typing import Any, Awaitable +from etils import epath from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.handlers import composite_handler from orbax.checkpoint.experimental.v1._src.handlers import registration @@ -23,16 +24,15 @@ from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import format_utils -from orbax.checkpoint.experimental.v1._src.path import types +from orbax.checkpoint.experimental.v1._src.path import types as path_types InvalidLayoutError = checkpoint_layout.InvalidLayoutError CompositeHandler = composite_handler.CompositeHandler -Path = types.Path +Path = path_types.Path CheckpointLayout = checkpoint_layout.CheckpointLayout -ORBAX_CHECKPOINT_INDICATOR_FILE = ( - composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE -) +PYTREE_METADATA_FILE = "_METADATA" +ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint" _V0_ERROR_MESSAGE = ( @@ -48,6 +48,23 @@ ) +def is_orbax_checkpoint(path: path_types.PathLike) -> bool: + """Determines if the given path is an Orbax checkpoint. + + Args: + path: The path to the checkpoint directory. + + Returns: + True if the path is an Orbax checkpoint, False otherwise. + """ + path = epath.Path(path) + try: + OrbaxLayout(path).validate() + return True + except InvalidLayoutError: + return False + + class OrbaxLayout(CheckpointLayout): """OrbaxLayout. @@ -96,7 +113,7 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]: custom_metadata=step_metadata.custom_metadata, ) - def validate(self): + async def validate(self): try: format_utils.validate_checkpoint_directory(self._path) if self.has_indicator_file: @@ -107,7 +124,7 @@ def validate(self): f" {_GENERAL_ERROR_MESSAGE}" ) from e - def validate_pytree(self, checkpointable_name: str | None) -> None: + async def validate_pytree(self, checkpointable_name: str | None) -> None: """Validates the given path as a PyTree checkpoint.""" try: format_utils.validate_pytree_checkpoint( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py index 960197c28..52413cb7e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py @@ -55,22 +55,22 @@ def setUp(self): custom_metadata=self.custom_metadata, ) - def test_valid_orbax_checkpoint(self): + async def test_valid_orbax_checkpoint(self): layout = OrbaxLayout(self.orbax_path / '0') - layout.validate() + await layout.validate() - def test_invalid_orbax_checkpoint(self): + async def test_invalid_orbax_checkpoint(self): layout = OrbaxLayout(self.safetensors_path) with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() - def test_validate_fails_not_directory(self): + async def test_validate_fails_not_directory(self): layout = OrbaxLayout(self.orbax_path / '1') # This tests `format_utils.validate_checkpoint_directory` with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() - def test_skips_metadata_validation_if_no_indicator_file(self): + async def test_skips_metadata_validation_if_no_indicator_file(self): layout = OrbaxLayout(self.orbax_path / '0') indicator_path = ( self.orbax_path @@ -81,25 +81,25 @@ def test_skips_metadata_validation_if_no_indicator_file(self): metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' self.assertTrue(metadata_path.exists()) metadata_path.rmtree() # Remove the metadata file - layout.validate() + await layout.validate() - def test_validate_fails_no_metadata_file(self): + async def test_validate_fails_no_metadata_file(self): layout = OrbaxLayout(self.orbax_path / '0') # This tests `format_utils.validate_checkpoint_metadata` metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' self.assertTrue(metadata_path.exists()) metadata_path.rmtree() # Remove the metadata file with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() - def test_validate_fails_tmp_directory(self): + async def test_validate_fails_tmp_directory(self): # This simulates a temporary directory created by Orbax (should fail) test_utils.save_fake_tmp_dir(self.orbax_path, 0, 'test_checkpoint.tmp') layout = OrbaxLayout( epath.Path(self.test_dir.full_path) / 'test_checkpoint.tmp' ) with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() async def test_load_orbax_checkpoint(self): layout = OrbaxLayout(self.orbax_path / '0') diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index a227835f0..bd63ee500 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -24,7 +24,7 @@ InvalidLayoutError = checkpoint_layout.InvalidLayoutError -def get_checkpoint_layout( +async def get_checkpoint_layout( path: path_types.PathLike, layout_enum: options_lib.CheckpointLayout ) -> checkpoint_layout.CheckpointLayout: """Returns the checkpoint layout class for the given path. @@ -52,7 +52,7 @@ def get_checkpoint_layout( try: layout = layout_class(path) - layout.validate() + await layout.validate() return layout except InvalidLayoutError as e: raise InvalidLayoutError( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index 8dca1215f..9f72b0252 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -22,6 +22,7 @@ import jax import numpy as np from orbax.checkpoint._src.arrays import numpy_utils +from orbax.checkpoint._src.path import async_path from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import format_utils @@ -280,8 +281,11 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]: custom_metadata=custom_metadata, ) - def validate(self): - if self._path.is_file() and self._path.suffix == ".safetensors": + async def validate(self): + if ( + await async_path.is_file(self._path) + and self._path.suffix == ".safetensors" + ): return else: raise InvalidLayoutError( @@ -290,7 +294,7 @@ def validate(self): " suffix." ) - def validate_pytree(self, checkpointable_name: str | None) -> None: + async def validate_pytree(self, checkpointable_name: str | None) -> None: return async def load( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py index 87a278d38..d8ecbdab3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py @@ -56,27 +56,27 @@ def setUp(self): ) saving.save_pytree(self.orbax_path, self.object_to_save) - def test_valid_safetensors_checkpoint(self): + async def test_valid_safetensors_checkpoint(self): layout = SafetensorsLayout(self.safetensors_path) - layout.validate() + await layout.validate() - def test_invalid_safetensors_checkpoint_orbax(self): + async def test_invalid_safetensors_checkpoint_orbax(self): layout = SafetensorsLayout(self.orbax_path / '0') with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() - def test_validate_fails_not_file(self): + async def test_validate_fails_not_file(self): layout = SafetensorsLayout(epath.Path(self.test_dir.full_path)) with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() - def test_validate_fails_wrong_suffix(self): + async def test_validate_fails_wrong_suffix(self): wrong_suffix_path = ( epath.Path(self.test_dir.full_path) / 'test_checkpoint.txt' ) layout = SafetensorsLayout(wrong_suffix_path) with self.assertRaises(InvalidLayoutError): - layout.validate() + await layout.validate() @parameterized.product( dtype=[ diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index 5276a6c27..f93bfe8bf 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -30,6 +30,7 @@ from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import format_utils from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.synchronization import asyncio_utils from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -102,6 +103,7 @@ def load_pytree( The restored PyTree. """ start_time = time.time() + asyncio_utils.maybe_apply_nest_asyncio() logging.info('Loading checkpoint from %s.', path) path = epath.Path(path) @@ -114,11 +116,15 @@ def load_pytree( path = epath.Path(path) checkpointable_name = path.name path = path.parent - layout = layout_registry.get_checkpoint_layout( - path, context_lib.get_context().checkpoint_layout - ) - layout.validate_pytree(checkpointable_name) + async def _get_layout(): + layout = await layout_registry.get_checkpoint_layout( + path, context_lib.get_context().checkpoint_layout + ) + await layout.validate_pytree(checkpointable_name) + return layout + + layout = asyncio.run(_get_layout()) return _load_checkpointables_impl( layout, @@ -180,12 +186,13 @@ def load_checkpointables( Raises: FileNotFoundError: If the checkpoint path does not exist. """ - start_time = time.time() + asyncio_utils.maybe_apply_nest_asyncio() logging.info('Loading checkpoint from %s.', path) path = epath.Path(path) - layout = layout_registry.get_checkpoint_layout( - path, context_lib.get_context().checkpoint_layout + context = context_lib.get_context() + layout = asyncio.run( + layout_registry.get_checkpoint_layout(path, context.checkpoint_layout) ) return _load_checkpointables_impl( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py index cd1557988..23bb03eb0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py @@ -26,6 +26,7 @@ from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import format_utils from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.synchronization import asyncio_utils CheckpointMetadata = metadata_types.CheckpointMetadata @@ -87,6 +88,7 @@ def _get_abstract_array(arr): Returns: A `CheckpointMetadata[PyTreeMetadata]` object. """ + asyncio_utils.maybe_apply_nest_asyncio() context = context_lib.get_context() path = epath.Path(path) @@ -94,13 +96,15 @@ def _get_abstract_array(arr): checkpointable_name = path.name path = path.parent - layout = layout_registry.get_checkpoint_layout( - path, context.checkpoint_layout - ) - layout.validate_pytree(checkpointable_name) + async def _get_layout(): + layout = await layout_registry.get_checkpoint_layout( + path, context.checkpoint_layout + ) + await layout.validate_pytree(checkpointable_name) + return layout + layout = asyncio.run(_get_layout()) metadata = _checkpointables_metadata_impl(layout) - return CheckpointMetadata[PyTreeMetadata]( metadata=metadata.metadata[checkpointable_name], init_timestamp_nsecs=metadata.init_timestamp_nsecs, @@ -136,13 +140,12 @@ def checkpointables_metadata( Returns: A `CheckpointMetadata[dict[str, Any]]` object. """ + asyncio_utils.maybe_apply_nest_asyncio() path = epath.Path(path) context = context_lib.get_context() - layout = layout_registry.get_checkpoint_layout( - path, context.checkpoint_layout + layout = asyncio.run( + layout_registry.get_checkpoint_layout(path, context.checkpoint_layout) ) - layout.validate() - return _checkpointables_metadata_impl(layout) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index b3411fc9a..b820fd879 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -22,7 +22,6 @@ from absl import logging from etils import epath import jax -import nest_asyncio from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.metadata import step_metadata_serialization @@ -36,6 +35,7 @@ from orbax.checkpoint.experimental.v1._src.path import format_utils from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.saving import path_utils as saving_path_utils +from orbax.checkpoint.experimental.v1._src.synchronization import asyncio_utils from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types @@ -264,13 +264,6 @@ def create_save_response( ) -def _maybe_apply_nest_asyncio(): - try: - nest_asyncio.apply() - except RuntimeError: - pass - - def save_checkpointables_impl( path: path_types.PathLike, checkpointables: dict[str, Any], @@ -281,7 +274,7 @@ def save_checkpointables_impl( partial_save: bool = False, ) -> async_types.AsyncResponse[None]: """See caller docstrings.""" - _maybe_apply_nest_asyncio() + asyncio_utils.maybe_apply_nest_asyncio() context = context_lib.get_context() path = epath.Path(path) path_exists = path.exists() if partial_save else False diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py new file mode 100644 index 000000000..f6bff8ddf --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py @@ -0,0 +1,24 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for asyncio usage.""" + +import nest_asyncio + + +def maybe_apply_nest_asyncio(): + try: + nest_asyncio.apply() + except RuntimeError: + pass diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index a862ddc72..5942115bb 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -252,6 +252,13 @@ def _eq(x, y): jax.tree.map(_eq, expected, actual, is_leaf=lambda x: x is None) +def assert_tree_same_structure(testclass, expected, actual): + """Asserts that two PyTrees have the same structure.""" + expected_structure = jax.tree.structure(expected) + actual_structure = jax.tree.structure(actual) + testclass.assertEqual(expected_structure, actual_structure) + + def setup_pytree(add: int = 0): """Creates a numpy PyTree for testing.""" pytree = {