diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 658f3d2e9..ac6684837 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree +leaves can also be saved as individual checkpointables. +- #v1 Modify LeafHandler definitions so that `AbstractLeaf` or +`Type[AbstractLeaf]` are always accepted as valid abstract values. + ## [0.11.25] - 2025-09-10 ### Changed diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py index 667228b02..b9d3cfa6e 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py @@ -299,18 +299,23 @@ class MultiProcessTest(absltest.TestCase): def setUp(self): """Start distributed service.""" super().setUp() - assert jax.process_count() == NUM_PROCESSES.value, ( - jax.process_count(), - NUM_PROCESSES.value, - ) - # Make sure all processes are at the same test case. - client = multihost.get_jax_distributed_client() - # Note that the name of this barrier is long and complicated, to prevent - # any collisions with barriers in user test code. - client.wait_at_barrier( - f"multiprocess_test_ensure_all_processes_arrive_at_test_case_{self._testMethodName}", - 10000, - ) + if multihost.is_pathways_backend(): + assert ( + jax.process_count() == 1 + ), "Expected 1 process for Pathways backend." + else: + assert jax.process_count() == NUM_PROCESSES.value, ( + jax.process_count(), + NUM_PROCESSES.value, + ) + # Make sure all processes are at the same test case. + client = multihost.get_jax_distributed_client() + # Note that the name of this barrier is long and complicated, to prevent + # any collisions with barriers in user test code. 
+ client.wait_at_barrier( + f"multiprocess_test_ensure_all_processes_arrive_at_test_case_{self._testMethodName}", + 10000, + ) def multiprocess_create_tempdir( self, name: str | None = None diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py index a3afcee5d..611dc097e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py @@ -213,12 +213,15 @@ def get_handlers_for_save( self, checkpointables: dict[str, Any] ) -> dict[str, handler_types.CheckpointableHandler]: """Returns a mapping from checkpointable name to handler.""" - return { + result = { checkpointable_name: registration.resolve_handler_for_save( self._handler_registry, checkpointable, name=checkpointable_name ) for checkpointable_name, checkpointable in checkpointables.items() } + for name, handler in result.items(): + logging.info('Resolved handler type %s for %s', type(handler), name) + return result def get_handlers_for_load( self, directory: path_types.Path, abstract_checkpointables: dict[str, Any] @@ -247,6 +250,8 @@ def get_handlers_for_load( handler_typestr=handler_typestr, ) loadable_checkpointable_names_to_handlers[name] = handler + for name, handler in loadable_checkpointable_names_to_handlers.items(): + logging.info('Resolved handler type %s for %s', type(handler), name) return loadable_checkpointable_names_to_handlers def _get_saved_handler_typestrs( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py index e731e876a..37ec36447 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py @@ -23,6 +23,7 @@ from typing import Type from orbax.checkpoint.experimental.v1._src.handlers import json_handler +from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler from orbax.checkpoint.experimental.v1._src.handlers import proto_handler from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler from orbax.checkpoint.experimental.v1._src.handlers import registration @@ -54,3 +55,9 @@ def _try_register_handler( _try_register_handler( pytree_handler.PyTreeHandler, format_utils.PYTREE_CHECKPOINTABLE_KEY ) + +# Registration for leaf types that can be treated as distinct checkpointables. +_try_register_handler(leaf_handler.ShardedArrayHandler) +_try_register_handler(leaf_handler.ArrayHandler) +_try_register_handler(leaf_handler.ScalarHandler) +_try_register_handler(leaf_handler.StringHandler) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py new file mode 100644 index 000000000..0c47a24ea --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler.py @@ -0,0 +1,114 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for :py:class:`serialization.LeafHandler`. + +This :py:class:`CheckpointableHandler` is a wrapper for checkpointables where +support is already implemented at the PyTree leaf level. +""" + +from typing import Any, Awaitable, TypeVar + +import jax +import numpy as np +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler +from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.serialization import registry +from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types + + +Leaf = TypeVar('Leaf') +AbstractLeaf = TypeVar('AbstractLeaf') + + +class _LeafHandler(handler_types.CheckpointableHandler[Leaf, AbstractLeaf]): + """Base class for handlers that operate on individual PyTree leaves. + + This handler wraps `PyTreeHandler` to provide support for checkpointables + that are single leaves in a PyTree. + """ + + def __init__(self): + self._context = context_lib.get_context() + + async def save( + self, directory: path_types.PathAwaitingCreation, checkpointable: Leaf + ) -> Awaitable[None]: + return await pytree_handler.PyTreeHandler().save( + directory, [checkpointable] + ) + + async def load( + self, + directory: path_types.Path, + abstract_checkpointable: AbstractLeaf | None = None, + ) -> Awaitable[Leaf]: + if abstract_checkpointable is None: + abstract_pytree = None + else: + abstract_pytree = [abstract_checkpointable] + + background_load = await pytree_handler.PyTreeHandler().load( + directory, abstract_pytree + ) + + async def background_load_wrapper() -> Leaf: + loaded_pytree = await background_load + return loaded_pytree[0] + + return background_load_wrapper() + + async def metadata(self, directory: path_types.Path) -> AbstractLeaf: + pytree_metadata = await pytree_handler.PyTreeHandler().metadata(directory) + return pytree_metadata[0] + + def is_handleable(self, checkpointable: Any) -> bool: + try: + pytree_handler.PyTreeHandler().validate_leaves_handleable( + [checkpointable] + ) + return True + except registry.UnregisteredTypeError: + return False + + def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: + try: + pytree_handler.PyTreeHandler().validate_abstract_leaves_handleable( + [abstract_checkpointable] + ) + return True + except registry.UnregisteredTypeError: + return False + + +class ShardedArrayHandler( + _LeafHandler[jax.Array, serialization_types.AbstractShardedArray] +): + pass + + +class ArrayHandler(_LeafHandler[np.ndarray, serialization_types.AbstractArray]): + pass + + +class StringHandler(_LeafHandler[str, serialization_types.AbstractString]): + pass + + +class ScalarHandler( + _LeafHandler[serialization_types.Scalar, serialization_types.AbstractScalar] +): + pass diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py new file mode 
100644 index 000000000..4ae0db354 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/leaf_handler_test.py @@ -0,0 +1,139 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Awaitable + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import jax +from jax import numpy as jnp +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.arrays import abstract_arrays +from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler +from orbax.checkpoint.experimental.v1._src.path import types as path_types +from orbax.checkpoint.experimental.v1._src.serialization import registry +from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils +from orbax.checkpoint.experimental.v1._src.tree import types as tree_types + + +PathAwaitingCreation = path_types.PathAwaitingCreation +PathLike = path_types.PathLike +Path = path_types.Path +Json = tree_types.JsonType +create_test_handler = handler_test_utils.create_test_handler + +Leaf = leaf_handler.Leaf +AbstractLeaf = leaf_handler.AbstractLeaf + + +async def _run_awaitable(awaitable: Awaitable[Any]) -> Any: + return await awaitable + + +class FooHandler( + leaf_handler._LeafHandler[ + handler_test_utils.Foo, handler_test_utils.AbstractFoo + ] +): + + def is_handleable(self, checkpointable: Any) -> bool: + return isinstance(checkpointable, handler_test_utils.Foo) + + def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None: + return isinstance(abstract_checkpointable, handler_test_utils.AbstractFoo) + + +class LeafHandlerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.directory = epath.Path( + self.create_tempdir(name='checkpointing_test').full_path + ) + self.values_and_abstract_values = { + leaf_handler.ShardedArrayHandler: [ + ( + jnp.arange(8), + abstract_arrays.to_shape_dtype_struct(jnp.arange(8)), + ), + (jnp.arange(8), jax.ShapeDtypeStruct), + ], + leaf_handler.ArrayHandler: [ + (np.arange(8), np.empty_like(np.arange(8))) + ], + leaf_handler.ScalarHandler: [ + (123, int), + (123, 0), + (123.456, float), + (123.456, 0.0), + ], + leaf_handler.StringHandler: [('test', str), ('test', '_')], + } + + def validate_load( + self, + handler: handler_test_utils.TestHandler[Leaf, AbstractLeaf], + value: Leaf, + abstract_value: AbstractLeaf, + directory: Path | None = None, + ): + directory = directory or self.directory + with self.subTest('load_with_abstract'): + restored = handler.load(directory, abstract_value) + test_utils.assert_array_equal(self, value, restored) + with self.subTest('load_without_abstract'): + restored = handler.load(directory) + test_utils.assert_array_equal(self, value, restored) + + @parameterized.parameters( + leaf_handler.ShardedArrayHandler, + leaf_handler.ArrayHandler, + leaf_handler.ScalarHandler, + leaf_handler.StringHandler, + ) + def 
test_save_load(self, handler_cls):
+    handler = create_test_handler(handler_cls)
+    test_cases = self.values_and_abstract_values[handler_cls]
+
+    self.assertFalse(handler.is_handleable(handler_test_utils.Foo(1, 'hi')))
+    self.assertFalse(
+        handler.is_abstract_handleable(handler_test_utils.AbstractFoo())
+    )
+
+    for i, (value, abstract_value) in enumerate(test_cases):
+      name = str(i)
+      with self.subTest(f'value={value}, abstract_value={abstract_value}'):
+        logging.info(
+            'Subtest: value=%s, abstract_value=%s', value, abstract_value
+        )
+        self.assertTrue(handler.is_handleable(value))
+        self.assertTrue(handler.is_abstract_handleable(abstract_value))
+        handler.save(self.directory / name, value)
+        self.validate_load(
+            handler, value, abstract_value, directory=self.directory / name
+        )
+
+  def test_unregistered_type(self):
+    handler = create_test_handler(FooHandler)
+    with self.assertRaises(registry.UnregisteredTypeError):
+      handler.save(self.directory, handler_test_utils.Foo(1, 'hi'))
+
+    with self.assertRaises(registry.UnregisteredTypeError):
+      handler.load(self.directory, handler_test_utils.AbstractFoo())
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py
index fd1de0073..aa1281c81 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py
@@ -269,7 +269,7 @@ async def save(
       self, directory: path_types.PathAwaitingCreation, checkpointable: PyTree
   ) -> Awaitable[None]:
-    self._validate_leaves_handleable(checkpointable)
+    self.validate_leaves_handleable(checkpointable)
 
     commit_futures = await self._handler_impl.async_save(
         directory.path,
@@ -303,7 +303,25 @@ async def load(
       directory: path_types.Path,
       abstract_checkpointable: PyTree | None = None,
   ) -> Awaitable[PyTree]:
-    self._validate_abstract_leaves_handleable(abstract_checkpointable)
+    """Loads a PyTree from a checkpoint directory.
+
+    Args:
+      directory: The directory to load from.
+      abstract_checkpointable: The abstract checkpointable to load into. If
+        None, the handler will attempt to load the entire checkpoint using the
+        recorded metadata. Otherwise, the `abstract_checkpointable` is expected
+        to be a PyTree of abstract leaves. See :py:class:`.LeafHandler` for more
+        details. The abstract leaf may be a value of type `AbstractLeaf`,
+        `Type[AbstractLeaf]`, or `None`. E.g., if `AbstractLeaf` is
+        `AbstractFoo`, it is always valid to pass `AbstractFoo()`,
+        `AbstractFoo`, or `None`. Passing the latter two indicates that
+        metadata should be used to restore the leaf.
+
+    Returns:
+      An awaitable which can be awaited to complete the load operation and
+      obtain a PyTree.
+    """
+    self.validate_abstract_leaves_handleable(abstract_checkpointable)
     return self._background_load(directory, abstract_checkpointable)
 
   async def metadata(
@@ -318,7 +336,7 @@ def _unwrap(metadata):
 
     return jax.tree.map(_unwrap, v0_metadata)
 
-  def _validate_leaves_handleable(self, checkpointable: PyTree):
+  def validate_leaves_handleable(self, checkpointable: PyTree):
     missing_leaf_types = set()
 
     def _validate_handleable_leaf(leaf: Any):
@@ -335,14 +353,14 @@ def _validate_handleable_leaf(leaf: Any):
       )
 
     if missing_leaf_types:
-      raise ValueError(
+      raise registry.UnregisteredTypeError(
          'The following leaf types are not registered in the'
          f' `LeafHandlerRegistry`: [{missing_leaf_types}]. 
Please register a' ' `LeafHandler` for each type in the `LeafHandlerRegistry` and' ' assign it into the `PyTreeOptions` in the `Context`.' ) - def _validate_abstract_leaves_handleable( + def validate_abstract_leaves_handleable( self, abstract_checkpointable: PyTree ): missing_abstract_leaf_types = set() @@ -361,7 +379,7 @@ def _validate_handleable_leaf(leaf: Any): ) if missing_abstract_leaf_types: - raise ValueError( + raise registry.UnregisteredTypeError( 'The following abstract leaf types are not registered in the' f' `LeafHandlerRegistry`: [{missing_abstract_leaf_types}]. Please' ' register a `LeafHandler` for each type in the' @@ -370,7 +388,10 @@ def _validate_handleable_leaf(leaf: Any): ) def is_handleable(self, checkpointable: Any) -> bool: + logging.info('is_handleable: %s', checkpointable) try: + if not checkpointable: + return False # If it's a leaf or an empty pytree container, it's not handleable. return not jax.tree_util.treedef_is_leaf( jax.tree.structure(checkpointable) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test_base.py deleted file mode 100644 index ebc7d6a3c..000000000 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test_base.py +++ /dev/null @@ -1,2163 +0,0 @@ -# Copyright 2025 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Common test cases for PyTreeHandler.""" - -# pylint: disable=protected-access, missing-function-docstring - -from __future__ import annotations - -import asyncio -import contextlib -import dataclasses -import datetime -import functools -import json -import threading -from typing import Any, Awaitable, Iterator, List, Sequence, Type -from unittest import mock - -from absl.testing import parameterized -import aiofiles -from etils import epath -import flax -import flax.training.train_state -import jax -from jax import numpy as jnp -from jax.experimental import mesh_utils -import numpy as np -import optax -from orbax.checkpoint import test_utils -from orbax.checkpoint import utils -from orbax.checkpoint._src.arrays import abstract_arrays -from orbax.checkpoint._src.handlers import pytree_checkpoint_handler -from orbax.checkpoint._src.metadata import array_metadata -from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib -from orbax.checkpoint._src.metadata import empty_values -from orbax.checkpoint._src.metadata import sharding as sharding_metadata -from orbax.checkpoint._src.metadata import tree as tree_metadata -from orbax.checkpoint._src.metadata import value as value_metadata -from orbax.checkpoint._src.serialization import replica_slices -from orbax.checkpoint._src.serialization import serialization -from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils -from orbax.checkpoint._src.serialization import type_handlers -from orbax.checkpoint._src.tree import utils as tree_utils -from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib -from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler -from orbax.checkpoint.experimental.v1._src.path import types as path_types -from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler -from orbax.checkpoint.experimental.v1._src.serialization import numpy_leaf_handler -from orbax.checkpoint.experimental.v1._src.serialization import registry -from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler -from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types -from orbax.checkpoint.experimental.v1._src.synchronization import multihost -from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils -from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils -from orbax.checkpoint.experimental.v1._src.tree import types as tree_types - - -PyTree = tree_types.PyTree -ParamInfo = pytree_checkpoint_handler.ParamInfo - -_SHARDING = '_sharding' -PYTREE_METADATA_FILE = pytree_checkpoint_handler.PYTREE_METADATA_FILE -ARRAY_METADATA_STORE = array_metadata_store_lib.Store() -PLACEHOLDER = type_handlers.PLACEHOLDER - -create_sharded_array = array_test_utils.create_sharded_array -create_numpy_pytree = array_test_utils.create_numpy_pytree -create_sharded_pytree = array_test_utils.create_sharded_pytree -as_abstract_type = array_test_utils.as_abstract_type - - -PathAwaitingCreation = path_types.PathAwaitingCreation -PathLike = path_types.PathLike -Path = path_types.Path - - -async def _run_awaitable(awaitable: Awaitable[Any]) -> Any: - return await awaitable - - -# Custom dataclasses for testing custom leaf handlers. PyType check requires -# these defines outside of the test. 
-@dataclasses.dataclass -class Point: - x: int - y: float - - -@dataclasses.dataclass -class AbstractPoint: - x: Type[int] = int - y: Type[float] = float - - -class PointLeafHandler(serialization_types.LeafHandler[Point, AbstractPoint]): - """A custom leaf handler for testing.""" - - def __init__(self, context: context_lib.Context | None = None): - del context - - async def serialize( - self, - params: Sequence[serialization_types.SerializationParam[Point]], - serialization_context: serialization_types.SerializationContext, - ) -> Awaitable[None]: - - async def _background_serialize(): - if multihost.is_primary_host(0): - # make sure the parent directory is created - await serialization_context.parent_dir.await_creation() - - for param in params: - async with aiofiles.open( - serialization_context.parent_dir.path / f'{param.name}.txt', - 'w', - ) as f: - await f.write(json.dumps(dataclasses.asdict(param.value))) - - return _background_serialize() - - async def deserialize( - self, - params: Sequence[serialization_types.DeserializationParam[AbstractPoint]], - deserialization_context: serialization_types.DeserializationContext, - ) -> Awaitable[Sequence[Point]]: - - async def _background_deserialize(): - ret = [] - for param in params: - async with aiofiles.open( - deserialization_context.parent_dir / f'{param.name}.txt', - 'r', - ) as f: - ret.append(Point(**json.loads(await f.read()))) - - return ret - - return _background_deserialize() - - async def metadata( - self, - params: Sequence[serialization_types.DeserializationParam[None]], - deserialization_context: serialization_types.DeserializationContext, - ) -> Sequence[AbstractPoint]: - return [AbstractPoint()] * len(params) - - -def create_mixed_format_pytree( - *, - add: int = 0, - strings: bool = False, - parent_key: str | None = None, - include_scalars: bool = True, -) -> PyTree: - """Creates a PyTree with different leaf types for testing. - - Args: - add: Adds the specified value to numeric leafs. - strings: If true, adds string leaves to the tree. - parent_key: If provided, keys will be contained within a dictionary under - this key. - include_scalars: If true, adds scalar leaves to the tree. - - Returns: - PyTree - """ - numpy_pytree, abstract_numpy_pytree = create_numpy_pytree( - add=add, include_scalars=include_scalars - ) - sharded_pytree, abstract_sharded_pytree = create_sharded_pytree( - add=add, include_scalars=include_scalars - ) - if parent_key: - numpy_pytree = {parent_key: numpy_pytree} - sharded_pytree = {parent_key: sharded_pytree} - abstract_numpy_pytree = {parent_key: abstract_numpy_pytree} - abstract_sharded_pytree = {parent_key: abstract_sharded_pytree} - mixed_pytree = { - 'numpy': numpy_pytree, - 'sharded': sharded_pytree, - } - abstract_mixed_pytree = { - 'numpy': abstract_numpy_pytree, - 'sharded': abstract_sharded_pytree, - } - if strings: - mixed_pytree['foo'] = 'foo_val' - mixed_pytree['bar'] = 'bar_val' - abstract_mixed_pytree['foo'] = '' - abstract_mixed_pytree['bar'] = '' - return mixed_pytree, abstract_mixed_pytree - - -def _raise_file_not_found_error(*args, **kwargs): - del args, kwargs - raise FileNotFoundError() - - -# Not in common util because we need to eliminate OSS dependency on flax. 
-def init_flax_model(model): - params = model.init(jax.random.PRNGKey(0), jnp.ones([8, 8])) - tx = optax.adamw(learning_rate=0.001) - state = flax.training.train_state.TrainState.create( - apply_fn=model.apply, params=params, tx=tx - ) - return jax.tree.map(np.asarray, state) - - -def get_d_files(path: Path) -> list[Path]: - files = [] - for idx in range(multihost.process_count()): - d_path = path / f'ocdbt.process_{idx}' / 'd' - if not d_path.exists(): - continue - files.extend(list(d_path.iterdir())) - return files - - -@contextlib.contextmanager -def handler_with_options( - *, - create_array_storage_options_fn: ( - options_lib.PyTreeOptions.Saving.CreateArrayStorageOptionsFn | None - ) = None, - save_concurrent_bytes: int | None = None, - restore_concurrent_bytes: int | None = None, - use_ocdbt: bool = True, - use_zarr3: bool = False, - enable_padding_and_truncation: bool = True, - ocdbt_target_data_file_size: int | None = None, - enable_pinned_host_transfer: bool | None = None, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = ( - tree_metadata.PYTREE_METADATA_OPTIONS - ), - array_metadata_store: array_metadata_store_lib.Store | None = ( - ARRAY_METADATA_STORE - ), - enable_write_sharding_file: bool = True, - partial_load: bool = False, - leaf_handler_registry: ( - serialization_types.LeafHandlerRegistry | None - ) = None, -): - """Registers handlers with OCDBT support and resets when done.""" - context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - concurrent_bytes=save_concurrent_bytes, - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - ocdbt_target_data_file_size=ocdbt_target_data_file_size, - enable_pinned_host_transfer=enable_pinned_host_transfer, - array_metadata_store=array_metadata_store, - enable_write_sharding_file=enable_write_sharding_file, - use_replica_parallel=not utils.is_pathways_backend(), - ), - loading=options_lib.ArrayOptions.Loading( - concurrent_bytes=restore_concurrent_bytes, - enable_padding_and_truncation=enable_padding_and_truncation, - ), - ), - pytree_options=options_lib.PyTreeOptions( - saving=options_lib.PyTreeOptions.Saving( - create_array_storage_options_fn=create_array_storage_options_fn, - pytree_metadata_options=pytree_metadata_options, - ), - loading=options_lib.PyTreeOptions.Loading( - partial_load=partial_load, - ), - leaf_handler_registry=leaf_handler_registry, - ), - ) - - handler = handler_test_utils.create_test_handler( - pytree_handler.PyTreeHandler, context=context - ) - - try: - yield handler - finally: - pass - - -class PyTreeHandlerTestBase: - """Base test cases for PyTreeCheckpointHandler.""" - - class Test(parameterized.TestCase): - """Test class.""" - - def setUp(self): - super().setUp() - - self.pytree, self.abstract_pytree = create_sharded_pytree() - self.numpy_pytree, self.abstract_numpy_pytree = create_numpy_pytree() - - self.directory = epath.Path( - self.create_tempdir(name='checkpointing_test').full_path - ) - # TODO: b/365169723 - Add tests for support_rich_types=True. 
- self.pytree_metadata_options = tree_metadata.PyTreeMetadataOptions( - support_rich_types=False - ) - - # default to use_ocdbt=False, so we can test non-ocdbt handler first - self.handler = self.enter_context( - handler_with_options( - use_ocdbt=False, array_metadata_store=ARRAY_METADATA_STORE - ) - ) - test_utils.set_tensorstore_driver_for_test() - - test_utils.sync_global_processes( - 'PyTreeCheckpointHandlerTest:setup_complete' - ) - - def tearDown(self): - test_utils.sync_global_processes( - 'PyTreeCheckpointHandlerTest:tests_complete' - ) - super().tearDown() - - def validate_save( - self, - path: epath.Path, - abstract_pytree: PyTree | None, - expected: PyTree, - checkpoint_handler, - ): - """Validate save was performed correctly.""" - actual = checkpoint_handler.load(path, abstract_pytree) - test_utils.assert_tree_equal(self, expected, actual) - - def validate_metadata( - self, - *, - expected_reference_metadata_tree: PyTree, - actual_metadata: PyTree, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - """Validate metadata, provided the original tree that was saved.""" - expected_reference_metadata_tree = tree_metadata.serialize_tree( - expected_reference_metadata_tree, pytree_metadata_options - ) - - def _metadata(value): - if empty_values.is_supported_empty_value( - value, pytree_metadata_options - ): - return value - if isinstance(value, np.ndarray): - return numpy_leaf_handler.NumpyMetadata( - shape=value.shape, - dtype=value.dtype, - storage_metadata=value_metadata.StorageMetadata( - chunk_shape=value.shape, - ), - ) - if isinstance(value, jax.Array): - expected_sharding = sharding_metadata.from_jax_sharding( - value.sharding - ) - expected_chunk_shape = test_utils.get_expected_chunk_shape(value) - return array_leaf_handler.ArrayMetadata( - shape=value.shape, - sharding_metadata=expected_sharding, - dtype=value.dtype, - storage_metadata=value_metadata.StorageMetadata( - chunk_shape=expected_chunk_shape, - write_shape=( - expected_chunk_shape - if array_metadata_store is not None - else None - ), - ), - ) - if isinstance(value, (float, int)): - return np.float64 if isinstance(value, float) else np.int64 - if isinstance(value, str): - return str - if isinstance(value, optax.EmptyState): - return None - if isinstance(value, Point): - return AbstractPoint() - raise ValueError(f'Unrecognized type: {type(value)}.') - - expected_metadata = jax.tree.map( - _metadata, - expected_reference_metadata_tree, - is_leaf=tree_utils.is_empty_or_leaf, - ) - test_utils.assert_tree_equal(self, expected_metadata, actual_metadata) - - def test_get_param_names(self): - param_names = pytree_checkpoint_handler.get_param_names(self.pytree) - expected = { - 'a': 'a', - 'b': 'b', - 'c': { - 'a': 'c.a', - 'e': 'c.e', - }, - 'x': 'x', - 'y': 'y', - } - test_utils.assert_tree_equal(self, expected, param_names) - - def test_save_format(self): - pytree = {'a': 0, 'c': {'d': np.arange(3), 'e': {'f': 5}}, 'g': 10} - self.handler.save(self.directory, pytree) - fnames = ['a', 'c.d', 'c.e.f', 'g'] - paths = [self.directory / name for name in fnames] - for p in paths: - self.assertTrue(p.exists()) - self.assertTrue((p / '.zarray').exists()) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_sharding(self, use_ocdbt: bool): - if multihost.is_pathways_backend(): - self.skipTest('Sharding metadata not present on Pathways.') - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - pytree = { - 
'mlp/~/linear_0': { - 'a': self.pytree['a'], - 'b': self.pytree['b'], - 'c': {'a': self.pytree['c']['a'], 'e': self.pytree['c']['e']}, - } - } - abstract_pytree = jax.tree.map( - array_test_utils.as_abstract_type, pytree - ) - checkpoint_handler.save(self.directory, pytree) - - self.validate_save( - self.directory, - abstract_pytree, - pytree, - checkpoint_handler, - ) - - self.assertTrue((self.directory / _SHARDING).exists()) - with open(self.directory / _SHARDING, 'r') as file: - data = json.load(file) - self.assertCountEqual( - data.keys(), - { - 'bWxwL34vbGluZWFyXzAuYQ==', # mlp/~/linear_0.a - 'bWxwL34vbGluZWFyXzAuYg==', # mlp/~/linear_0.b - 'bWxwL34vbGluZWFyXzAuYy5h', # mlp/~/linear_0.c.a - 'bWxwL34vbGluZWFyXzAuYy5l', # mlp/~/linear_0.c.e - }, - ) - # mlp/~/linear_0.a - self.assertEqual( - sharding_metadata.NamedShardingMetadata.from_deserialized_dict( - json.loads(data['bWxwL34vbGluZWFyXzAuYQ==']) - ), - sharding_metadata.NamedShardingMetadata.from_jax_sharding( - pytree['mlp/~/linear_0']['a'].sharding - ), - ) - # mlp/~/linear_0.b - self.assertEqual( - sharding_metadata.NamedShardingMetadata.from_deserialized_dict( - json.loads(data['bWxwL34vbGluZWFyXzAuYg==']) - ), - sharding_metadata.NamedShardingMetadata.from_jax_sharding( - pytree['mlp/~/linear_0']['b'].sharding - ), - ) - # mlp/~/linear_0.c.a - self.assertEqual( - sharding_metadata.NamedShardingMetadata.from_deserialized_dict( - json.loads(data['bWxwL34vbGluZWFyXzAuYy5h']) - ), - sharding_metadata.NamedShardingMetadata.from_jax_sharding( - pytree['mlp/~/linear_0']['c']['a'].sharding - ), - ) - # mlp/~/linear_0.c.e - self.assertEqual( - sharding_metadata.NamedShardingMetadata.from_deserialized_dict( - json.loads(data['bWxwL34vbGluZWFyXzAuYy5l']) - ), - sharding_metadata.NamedShardingMetadata.from_jax_sharding( - pytree['mlp/~/linear_0']['c']['e'].sharding - ), - ) - - @parameterized.product( - use_ocdbt=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_disable_write_sharding_file( - self, - use_ocdbt: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - pytree, abstract_pytree = create_mixed_format_pytree() - with handler_with_options( - use_ocdbt=use_ocdbt, - array_metadata_store=array_metadata_store, - enable_write_sharding_file=False, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - self.validate_save( - self.directory, - abstract_pytree, - pytree, - checkpoint_handler, - ) - self.assertFalse((self.directory / _SHARDING).exists()) - - def test_sharding_variable_devices(self): - if multihost.is_pathways_backend(): - self.skipTest('Sharding metadata not present on Pathways.') - mesh_axes = jax.sharding.PartitionSpec( - 'x', - ) - devices_subset = [] - for idx in range(jax.process_count()): - for d in jax.devices(): - if d.process_index == idx: - devices_subset.append(d) - break - pytree = { - 'a': test_utils.create_sharded_array( - np.arange(16), - jax.sharding.Mesh(devices_subset, ('x',)), - mesh_axes, - ), - 'b': test_utils.create_sharded_array( - np.arange(16), jax.sharding.Mesh(jax.devices(), ('x',)), mesh_axes - ), - } - - self.handler.save(self.directory, pytree) - self.assertTrue((self.directory / _SHARDING).exists()) - a_sharding_metadata = sharding_metadata.NamedShardingMetadata( - shape=np.array([2]), - axis_names=['x'], - partition_spec=('x',), - device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh( - jax.sharding.Mesh(devices_subset, ('x',)) - ), - ) - b_sharding_metadata = sharding_metadata.NamedShardingMetadata( - 
shape=np.array([8]), - axis_names=['x'], - partition_spec=('x',), - device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh( - jax.sharding.Mesh(jax.devices(), ('x',)) - ), - ) - - restored_metadata = self.handler.metadata(self.directory) - self.assertEqual( - a_sharding_metadata, - restored_metadata['a'].sharding_metadata, - ) - self.assertEqual( - b_sharding_metadata, - restored_metadata['b'].sharding_metadata, - ) - self.assertEqual( - pytree['a'].sharding, - restored_metadata['a'].sharding, - ) - self.assertEqual( - pytree['b'].sharding, - restored_metadata['b'].sharding, - ) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_main(self, use_ocdbt: bool): - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - self.validate_save( - self.directory, - self.abstract_pytree, - self.pytree, - checkpoint_handler, - ) - self.assertEqual( - type_handlers.is_ocdbt_checkpoint(self.directory), use_ocdbt - ) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_keys_with_slashes(self, use_ocdbt: bool): - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - pytree = { - 'a': np.arange(2), - 'b/c': np.arange(4), - } - checkpoint_handler.save(self.directory, pytree) - self.validate_save( - self.directory, - None, - pytree, - checkpoint_handler, - ) - - def test_save_non_sharded(self): - self.handler.save(self.directory, self.numpy_pytree) - self.validate_save( - self.directory, - None, - self.numpy_pytree, - self.handler, - ) - - @parameterized.product( - use_ocdbt=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_save_mixed( - self, - use_ocdbt: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store - ) as checkpoint_handler: - pytree, abstract_pytree = create_mixed_format_pytree(strings=True) - checkpoint_handler.save(self.directory, pytree) - self.validate_save( - self.directory, - abstract_pytree, - pytree, - checkpoint_handler, - ) - if use_ocdbt: - expected_files_and_directories = [ - '_strings.json', - 'manifest.ocdbt', - 'ocdbt.process_0', - ] - else: - expected_files_and_directories = [ - '_strings.json', - 'numpy.a', - 'numpy.b', - 'numpy.c.a', - 'numpy.c.e', - ] - self.assertContainsSubset( - expected_files_and_directories, - [f.name for f in self.directory.iterdir()], - ) - self.validate_metadata( - expected_reference_metadata_tree=pytree, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - @parameterized.product( - use_ocdbt=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_save_strings( - self, - use_ocdbt: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - if use_ocdbt and multihost.is_pathways_backend(): - self.skipTest('Pathways + OCDBT not supported.') - - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store - ) as checkpoint_handler: - pytree, abstract_pytree = create_mixed_format_pytree(strings=True) - - checkpoint_handler.save(self.directory, pytree) - self.validate_save( - self.directory, - abstract_pytree, - pytree, - checkpoint_handler, - ) - self.validate_metadata( - expected_reference_metadata_tree=pytree, - actual_metadata=checkpoint_handler.metadata(self.directory), - 
pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - self.assertTrue((self.directory / '_strings.json').exists()) - with open(self.directory / '_strings.json') as file: - data = json.load(file) - self.assertCountEqual( - data.keys(), - {'foo', 'bar'}, - None, - ) - self.assertEqual(data['foo'], 'foo_val') - self.assertEqual(data['bar'], 'bar_val') - - def test_cast(self): - pytree, abstract_pytree = create_mixed_format_pytree( - include_scalars=False - ) - origin_dtype = np.int64 - save_dtype = np.uint32 - restore_dtype = np.float64 - - def check_dtype(x, dtype): - if not utils.is_scalar(x): - self.assertEqual(x.dtype, dtype) - - def set_dtype(v, dtype): - if hasattr(v, 'dtype'): - if isinstance(v, jax.ShapeDtypeStruct): - v = v.update(dtype=dtype) - else: - setattr(v, 'dtype', dtype) - return v - - with self.subTest('check_origin_dtype'): - jax.tree.map(functools.partial(check_dtype, dtype=origin_dtype), pytree) - jax.tree.map( - functools.partial(check_dtype, dtype=origin_dtype), abstract_pytree - ) - - with handler_with_options( - use_ocdbt=False, - create_array_storage_options_fn=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( - dtype=save_dtype - ), - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - - with self.subTest('check_restore_dtype'): - abstract_pytree = jax.tree.map( - functools.partial(set_dtype, dtype=restore_dtype), abstract_pytree - ) - restored = self.handler.load(self.directory, abstract_pytree) - jax.tree.map( - functools.partial(check_dtype, dtype=restore_dtype), restored - ) - - with self.subTest('check_save_dtype'): - restored = self.handler.load(self.directory) - jax.tree.map(functools.partial(check_dtype, dtype=save_dtype), restored) - - @parameterized.product(cast_to=(int, float, 0, 0.0)) - def test_cast_scalar_types(self, cast_to): - pytree = {'a': 5, 'b': 6.1} - abstract_pytree = { - 'a': cast_to, - 'b': cast_to, - } - - self.handler.save(self.directory, pytree) - restored = self.handler.load(self.directory, abstract_pytree) - expected_type = cast_to if isinstance(cast_to, type) else type(cast_to) - self.assertIsInstance(restored['a'], expected_type) - self.assertIsInstance(restored['b'], expected_type) - - @parameterized.product( - use_ocdbt=(True, False), - use_zarr3=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_save_restore( - self, - use_ocdbt: bool, - use_zarr3: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - with handler_with_options( - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - array_metadata_store=array_metadata_store, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - restored = checkpoint_handler.load( - self.directory, - self.abstract_pytree, - ) - test_utils.assert_tree_equal(self, self.pytree, restored) - self.validate_metadata( - expected_reference_metadata_tree=self.pytree, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - def test_save_async(self): - # The pytree must be larger so that saving doesn't complete too quickly. 
- mesh = jax.sharding.Mesh(jax.devices(), 'x') - np.random.seed(42) - pytree = { - 'a': array_test_utils.create_sharded_array( - np.arange(2**20), - sharding=jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec('x') - ), - ), - 'b': array_test_utils.create_sharded_array( - np.random.uniform(size=2**15), - sharding=jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(None) - ), - ), - } - abstract_pytree = jax.tree.map(array_test_utils.as_abstract_type, pytree) - - start_serialize = threading.Event() - original_serialize = serialization.async_serialize_from_host - - def mock_serialize(*args, **kwargs): - start_serialize.wait() # Wait for explicit signal before proceeding. - return original_serialize(*args, **kwargs) - - def is_save_complete(directory): - return (directory / 'manifest.ocdbt').exists() - - # Serialization to disk does not start until receiving an explicit signal. - self.enter_context( - mock.patch.object( - serialization, 'async_serialize_from_host', new=mock_serialize - ) - ) - - with handler_with_options() as checkpoint_handler: - awaitable = checkpoint_handler.save_async(self.directory, pytree) - initial_d_files = get_d_files(self.directory) - self.assertFalse(is_save_complete(self.directory)) - start_serialize.set() - - asyncio.run(_run_awaitable(awaitable)) - final_d_files = get_d_files(self.directory) - self.assertNotEmpty(final_d_files) - self.assertNotEqual(len(initial_d_files), len(final_d_files)) - self.assertTrue(is_save_complete(self.directory)) - - restored = checkpoint_handler.load( - self.directory, - abstract_pytree, - ) - test_utils.assert_tree_equal(self, pytree, restored) - - def test_load_async(self): - with handler_with_options() as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - load_awaitable = checkpoint_handler.load_async( - self.directory, - self.abstract_pytree, - ) - restored = asyncio.run(_run_awaitable(load_awaitable)) - test_utils.assert_tree_equal(self, self.pytree, restored) - - @parameterized.product(use_ocdbt=(True, False)) - def test_load_reverse_mesh(self, use_ocdbt: bool): - if use_ocdbt and multihost.is_pathways_backend(): - self.skipTest('Pathways + OCDBT not supported.') - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - pytree, abstract_pytree = array_test_utils.create_sharded_pytree( - reverse_devices=True - ) - checkpoint_handler.save(self.directory, pytree) - restored = checkpoint_handler.load(self.directory, abstract_pytree) - test_utils.assert_tree_equal(self, pytree, restored) - - def test_load_multiple_steps(self): - for step in [0, 1]: - directory = self.directory / str(step) - if multihost.process_index() == 0: - directory.mkdir() - test_utils.sync_global_processes( - 'PyTreeCheckpointHandlerTest:test_load_different_mkdir' - ) - - pytree, abstract_pytree = create_mixed_format_pytree(add=step) - self.handler.save(directory, pytree) - - restored = self.handler.load(directory, abstract_pytree) - test_utils.assert_tree_equal(self, pytree, restored) - - def test_load_missing_checkpoint(self): - directory = self.directory / 'nothing' - with self.assertRaises(FileNotFoundError): - self.handler.load(directory) - - @parameterized.product( - use_ocdbt=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_flax_model( - self, - use_ocdbt: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - - @flax.struct.dataclass - class Params(flax.struct.PyTreeNode): - params: PyTree - opt_state: PyTree - - def 
make_state_with_optax(): - return Params( - params=self.numpy_pytree, - opt_state=(optax.EmptyState(), optax.EmptyState()), - ) - - def make_state_with_nones(): - return Params( - params=self.numpy_pytree, - opt_state=(None, None), - ) - - state = make_state_with_optax() - - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, state) - - with self.subTest('with_abstract_state'): - abstract_state = jax.tree.map( - array_test_utils.as_abstract_type, state - ) - restored = checkpoint_handler.load(self.directory, abstract_state) - expected_state = state - test_utils.assert_tree_equal(self, expected_state, restored) - self.validate_metadata( - expected_reference_metadata_tree=expected_state, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - with self.subTest('without_abstract_state'): - if multihost.is_pathways_backend(): - self.skipTest('Must provide abstract_pytree for Pathways.') - restored = checkpoint_handler.load(self.directory) - expected_state = tree_utils.serialize_tree( - make_state_with_nones(), - keep_empty_nodes=True, - ) - test_utils.assert_tree_equal(self, expected_state, restored) - self.validate_metadata( - expected_reference_metadata_tree=expected_state, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - @parameterized.product( - use_ocdbt=( - True, - False, - ), - data=( - {}, - {'a': {}, 'b': 3}, - [1, {}, 2], - None, - {'a': None, 'b': 3}, - [1, None, 2], - [], - [1, [], 2], - {'a': [], 'b': 3}, - ), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_empty_data( - self, - use_ocdbt: bool, - data: Any, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store - ) as checkpoint_handler: - if not data: - with self.assertRaisesRegex(ValueError, 'Found empty item'): - checkpoint_handler.save( - self.directory, - data, - ) - return - - checkpoint_handler.save(self.directory, data) - restored = checkpoint_handler.load(self.directory) - self.assertEqual(restored, data) - - self.validate_metadata( - expected_reference_metadata_tree=data, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - @parameterized.product( - use_ocdbt=(True, False), - array_metadata_store=(None, ARRAY_METADATA_STORE), - ) - def test_list( - self, - use_ocdbt: bool, - array_metadata_store: array_metadata_store_lib.Store | None, - ): - item = [1, 2, 5, 6] - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=array_metadata_store - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, item) - abstract_item = [0, 0, 0, 0] - restored = checkpoint_handler.load(self.directory, abstract_item) - self.assertListEqual(restored, item) - self.validate_metadata( - expected_reference_metadata_tree=[0, 0, 0, 0], - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) - - restored = checkpoint_handler.load(self.directory) - self.assertListEqual( - restored, - [ - np.asarray([1]), - np.asarray([2]), - 
np.asarray([5]), - np.asarray([6]), - ], - ) - - def test_no_metadata_file(self): - self.handler.save(self.directory, self.pytree) - metadata_file = self.directory / PYTREE_METADATA_FILE - if multihost.process_index() == 0: - self.assertTrue(metadata_file.exists()) - metadata_file.unlink() - test_utils.sync_global_processes('delete_metadata_file') - self.assertFalse(metadata_file.exists()) - with self.assertRaises(FileNotFoundError): - self.handler.metadata(self.directory) - - @parameterized.parameters((True,), (False,)) - def test_reshape_padding(self, enable_padding_and_truncation: bool): - mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',)) - axes = jax.sharding.PartitionSpec( - 'x', - ) - dtype = np.float32 - pytree = { - 'x': test_utils.create_sharded_array( - np.arange(8, dtype=dtype), mesh, axes - ) - } - abstract_pytree = { - 'x': jax.ShapeDtypeStruct( - shape=(16,), dtype=dtype, sharding=pytree['x'].sharding - ) - } - with handler_with_options( - enable_padding_and_truncation=enable_padding_and_truncation - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - if enable_padding_and_truncation: - restored = checkpoint_handler.load(self.directory, abstract_pytree) - expected = { - 'x': test_utils.create_sharded_array( - np.concatenate( - (np.arange(8, dtype=dtype), np.zeros(8, dtype=dtype)) - ), - mesh, - axes, - ) - } - test_utils.assert_tree_equal(self, expected, restored) - else: - with self.assertRaises(BaseException): - checkpoint_handler.load(self.directory, abstract_pytree) - - @parameterized.parameters((True,), (False,)) - def test_reshape_truncate(self, enable_padding_and_truncation: bool): - mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',)) - axes = jax.sharding.PartitionSpec( - 'x', - ) - dtype = np.float32 - pytree = { - 'x': test_utils.create_sharded_array( - np.arange(16, dtype=dtype), mesh, axes - ) - } - abstract_pytree = { - 'x': jax.ShapeDtypeStruct( - shape=(8,), dtype=dtype, sharding=pytree['x'].sharding - ) - } - - with handler_with_options( - enable_padding_and_truncation=enable_padding_and_truncation - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - if enable_padding_and_truncation: - restored = checkpoint_handler.load(self.directory, abstract_pytree) - expected = { - 'x': test_utils.create_sharded_array( - np.arange(8, dtype=dtype), mesh, axes - ) - } - test_utils.assert_tree_equal(self, expected, restored) - else: - with self.assertRaises(BaseException): - checkpoint_handler.load(self.directory, abstract_pytree) - - @parameterized.parameters( - (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('x', 'y'))), - (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('y', 'x'))), - (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('x',))), - (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec(('y',))), - (jax.sharding.PartitionSpec(('x', 'y')), jax.sharding.PartitionSpec()), - ( - jax.sharding.PartitionSpec(('x', 'y')), - jax.sharding.PartitionSpec(('x',)), - ), - ( - jax.sharding.PartitionSpec(('x', 'y')), - jax.sharding.PartitionSpec(('y',)), - ), - ( - jax.sharding.PartitionSpec(('x', 'y')), - jax.sharding.PartitionSpec(('y', 'x')), - ), - ( - jax.sharding.PartitionSpec(('x',)), - jax.sharding.PartitionSpec(('y',)), - ), - ) - def test_reshard(self, save_spec, restore_spec): - devices = jax.devices() - len_devices = len(devices) - self.assertGreaterEqual(len_devices, 4) - - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh((4, len_devices // 4)), ('x', 
'y') - ) - dtype = np.int32 - pytree = { - 'x': test_utils.create_sharded_array( - np.arange(len_devices, dtype=dtype), mesh, save_spec - ) - } - abstract_pytree = { - 'x': jax.ShapeDtypeStruct( - shape=(len_devices,), - dtype=dtype, - sharding=jax.sharding.NamedSharding(mesh, restore_spec), - ) - } - - self.handler.save(self.directory, pytree) - restored = self.handler.load(self.directory, abstract_pytree) - expected = { - 'x': test_utils.create_sharded_array( - np.arange(len_devices, dtype=dtype), mesh, restore_spec - ) - } - test_utils.assert_tree_equal(self, expected, restored) - - def test_load_non_ocdbt(self): - with handler_with_options(use_ocdbt=False) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - self.assertFalse(type_handlers.is_ocdbt_checkpoint(self.directory)) - with handler_with_options(use_ocdbt=True) as checkpoint_handler: - restored = checkpoint_handler.load( - self.directory, - self.abstract_pytree, - ) - test_utils.assert_tree_equal(self, self.pytree, restored) - - def test_load_non_ocdbt_mixed(self): - pytree, abstract_pytree = create_mixed_format_pytree(strings=True) - with handler_with_options(use_ocdbt=False) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - self.assertFalse(type_handlers.is_ocdbt_checkpoint(self.directory)) - with handler_with_options(use_ocdbt=True) as checkpoint_handler: - restored = checkpoint_handler.load(self.directory, abstract_pytree) - test_utils.assert_tree_equal(self, pytree, restored) - - def test_check_zarray(self): - self.handler.save(self.directory, self.pytree) - zarr_path = self.directory / 'a' / '.zarray' - zarr_path.unlink(missing_ok=True) - test_utils.sync_global_processes( - 'PyTreeCheckpointHandlerTest:delete_zarray' - ) - self.assertFalse(zarr_path.exists()) - with self.assertRaises(FileNotFoundError): - self.handler.load( - self.directory, - self.abstract_pytree, - ) - - def test_without_abstract_pytree(self): - arr = test_utils.create_sharded_array( - np.arange(8), - jax.sharding.Mesh(jax.devices(), ('x',)), - jax.sharding.PartitionSpec('x'), - ) - pytree = [arr] - self.handler.save(self.directory, pytree) - restored = self.handler.load(self.directory) - test_utils.assert_tree_equal(self, pytree, restored) - - @parameterized.product(use_ocdbt=(True, False)) - def test_masked_shape_dtype_struct(self, use_ocdbt: bool): - - def _should_mask(keypath): - return keypath[0].key == 'a' or ( - keypath[0].key == 'c' and keypath[1].key == 'e' - ) - - def _mask(keypath, x): - return optax.MaskedNode() if _should_mask(keypath) else x - - def _none(keypath, x): - return None if _should_mask(keypath) else x - - masked_tree = jax.tree_util.tree_map_with_path(_mask, self.pytree) - expected = jax.tree_util.tree_map_with_path(_none, self.pytree) - - with handler_with_options(use_ocdbt=use_ocdbt) as handler: - handler.save(self.directory, masked_tree) - if use_ocdbt: - self.assertTrue(type_handlers.is_ocdbt_checkpoint(self.directory)) - - # Restore it with state which was given before applying masking. - restored = handler.load( - self.directory, - jax.tree.map(abstract_arrays.to_shape_dtype_struct, self.pytree), - ) - test_utils.assert_tree_equal(self, expected, restored) - - # Restore it with state after applying masking to it. - restored = handler.load( - self.directory, - jax.tree.map(abstract_arrays.to_shape_dtype_struct, masked_tree), - ) - test_utils.assert_tree_equal(self, expected, restored) - - # Restore it without any state. 
- restored = handler.load( - self.directory, - self.abstract_pytree, - ) - test_utils.assert_tree_equal(self, expected, restored) - - def test_finalize(self): - with handler_with_options(use_ocdbt=True) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - process_index = multihost.process_index() - process_dir = ( - self.directory / f'{ts_utils.PROCESS_SUBDIR_PREFIX}{process_index}' - ) - self.assertTrue(process_dir.exists()) - self.assertTrue(process_dir.is_dir()) - self.assertTrue(type_handlers.is_ocdbt_checkpoint(self.directory)) - - @parameterized.product(use_ocdbt=(True, False)) - def test_unregistered_types(self, use_ocdbt: bool): - data = {'uncheckpointable_field': datetime.timedelta(seconds=5)} - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - with self.assertRaisesRegex( - ValueError, 'The following leaf types are not registered' - ): - checkpoint_handler.save( - self.directory, - data, - ) - - @parameterized.product( - target_data_file_size=[ - 50 * 1024, # 50KB - 10 * 1024, # 10KB - 0, - None, - ], - chunk_byte_size=[ - None, # unspecified - 5 * 1024, # 5KB - 100 * 1024, # greater than target_data_file_size - ], - use_zarr3=[True, False], - patch_default_ocdbt_data_file_size=[True, False], - ) - def test_ocdbt_target_data_file_size( - self, - target_data_file_size, - chunk_byte_size, - use_zarr3, - patch_default_ocdbt_data_file_size, - ): - """Test ocdbt_target_data_file_size.""" - array_len = 16 * 1024 # ~ 64KB of float data - custom_pytree = { - 'a': np.arange(array_len, dtype=np.int32), - 'b': np.arange(array_len * 2, dtype=np.float32), - 'c': { - 'a': ( - np.arange(array_len, dtype=np.int32).reshape( - 2, array_len // 2 - ) - ), - 'e': ( - np.arange(array_len * 2, dtype=np.float32).reshape( - 2, array_len - ) - ), - }, - } - shardings = { - 'a': self.abstract_pytree['a'].sharding, - 'b': self.abstract_pytree['b'].sharding, - 'c': { - 'a': self.abstract_pytree['c']['a'].sharding, - 'e': self.abstract_pytree['c']['e'].sharding, - }, - } - pytree = jax.tree.map(create_sharded_array, custom_pytree, shardings) - abstract_pytree = jax.tree.map(as_abstract_type, pytree) - - create_array_storage_options_fn = ( - lambda key, value: options_lib.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=chunk_byte_size - ) - ) - - new_ocdbt_target_data_file_size = ( - 1024 - if patch_default_ocdbt_data_file_size - else ts_utils._DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE - ) - with mock.patch.object( - ts_utils, - '_DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE', - new_ocdbt_target_data_file_size, - ): - if patch_default_ocdbt_data_file_size: - assert ts_utils._DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE == 1024 - with handler_with_options( - use_ocdbt=True, - use_zarr3=use_zarr3, - ocdbt_target_data_file_size=target_data_file_size, - create_array_storage_options_fn=create_array_storage_options_fn, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - - data_dir = self.directory / 'd' - self.assertTrue(data_dir.exists()) - self.assertTrue(data_dir.is_dir()) - - for f in data_dir.iterdir(): - if f.is_file(): - if target_data_file_size not in (0, None): - # it's expected the resulting file sizes can be larger than - # the target_data_file_size, so we give some buffer here - self.assertLessEqual( - f.stat().length, - target_data_file_size * 2.0, - ) # some room - if patch_default_ocdbt_data_file_size: - self.assertLessEqual( - f.stat().length, - ( - new_ocdbt_target_data_file_size * 4.0 - ), # TODO(niketkb): revisit culprit cl/786790774. 
- ) - - restored = checkpoint_handler.load(self.directory, abstract_pytree) - - test_utils.assert_tree_equal(self, pytree, restored) - - def test_local_registry(self): - - if multihost.is_pathways_backend(): - # This does not test anything on the pathways backend - # TODO(b/333114195): add proper pathways testing. - return - - class PlusOneHandler(scalar_leaf_handler.ScalarLeafHandler): - """A custom handler that adds one to all scalar values.""" - - def __init__(self, context: context_lib.Context | None = None): - super().__init__(context=context) - - async def serialize( - self, - params: Sequence[scalar_leaf_handler.ScalarSerializationParam], - serialization_context: serialization_types.SerializationContext, - ) -> Awaitable[None]: - updated_params = [ - scalar_leaf_handler.ScalarSerializationParam( - keypath=param.keypath, value=param.value + 1 - ) - for param in params - ] - - return await super().serialize(updated_params, serialization_context) - - leaf_registry = registry.BaseLeafHandlerRegistry() - leaf_registry.add(int, int, PlusOneHandler) - - with handler_with_options( - leaf_handler_registry=leaf_registry, - array_metadata_store=None, - use_zarr3=True, - ) as handler: - with self.assertRaisesRegex( - ValueError, 'The following leaf types are not registered' - ): - handler.save(self.directory, {'a': 3, 'b': 1.0}) - - handler.save(self.directory, {'a': 3}) - - with self.assertRaisesRegex( - ValueError, 'The following abstract leaf types are not registered' - ): - handler.load(self.directory, {'a': 3.0}) - - restored = handler.load(self.directory) - expected = {'a': 4} - self.assertEqual(restored, expected) - - def test_empty_custom_node(self): - - class PyTreeDict(dict): - pass - - jax.tree_util.register_pytree_node( - PyTreeDict, - lambda d: (tuple(d.values()), tuple(d.keys())), - lambda keys, values: PyTreeDict(dict(zip(keys, values))), - ) - - with self.assertRaisesRegex(ValueError, 'Found empty item'): - self.handler.save(self.directory, PyTreeDict()) - - self.handler.save(self.directory, {'a': PyTreeDict()}) - restored = self.handler.load(self.directory) - self.assertDictEqual({'a': {}}, restored) - - restored = self.handler.load(self.directory, {'a': PyTreeDict()}) - test_utils.assert_tree_equal(self, {'a': PyTreeDict()}, restored) - - @parameterized.parameters((5,), (9,)) - def test_concurrent_gb_save(self, limit_bytes): - # TODO(b/346811105): Enable for Pathways. - if multihost.is_pathways_backend(): - self.skipTest( - 'Disabled on Pathways because completion_times cannot updated by' - ' reference outside remote Python.' - ) - sleep_time = 1.0 - sharding = jax.sharding.NamedSharding( - jax.sharding.Mesh( - jax.devices(), - ('x',), - ), - jax.sharding.PartitionSpec( - None, - ), - ) - # 4 arrays, each has a single chunk, with 4 bytes each. - tree = jax.tree.map( - functools.partial( - array_test_utils.create_sharded_array, sharding=sharding - ), - { - 'a': np.arange(1, dtype=np.int32), - 'b': np.arange(1, dtype=np.int32), - 'c': np.arange(1, dtype=np.int32), - 'd': np.arange(1, dtype=np.int32), - }, - ) - byte_limiter = test_utils.get_byte_limiter(limit_bytes, sleep_time) - with mock.patch.object( - serialization, - 'get_byte_limiter', - new=lambda _: byte_limiter, - ), handler_with_options( - save_concurrent_bytes=limit_bytes, - ) as handler: - handler.save(self.directory, tree) - # Replicated shards are handled within the _write_array_shard function. - # Since shards are only saved once per replica, we only have to check - # the primary process. 
- completion_times = byte_limiter.completion_times - if multihost.process_index() == 0: - self.assertLen(completion_times, len(jax.tree.leaves(tree))) - test_utils.assert_every_n_is_x_apart( - self, - completion_times, - limit_bytes // np.int32().itemsize, - sleep_time, - ) - - @parameterized.parameters((5,), (9,)) - def test_concurrent_gb_restore(self, limit_bytes): - # TODO(b/346811105): Enable for Pathways. - if multihost.is_pathways_backend(): - self.skipTest( - 'Disabled on Pathways because completion_times cannot updated by' - ' reference outside remote Python.' - ) - sleep_time = 1.0 - sharding = jax.sharding.NamedSharding( - jax.sharding.Mesh( - jax.devices(), - ('x',), - ), - jax.sharding.PartitionSpec( - None, - ), - ) - # 4 arrays, each has a single chunk, with 4 bytes each. - tree = jax.tree.map( - functools.partial( - array_test_utils.create_sharded_array, sharding=sharding - ), - { - 'a': np.arange(1, dtype=np.int32), - 'b': np.arange(1, dtype=np.int32), - 'c': np.arange(1, dtype=np.int32), - 'd': np.arange(1, dtype=np.int32), - }, - ) - self.handler.save(self.directory, tree) - - byte_limiter = test_utils.get_byte_limiter(limit_bytes, sleep_time) - with mock.patch.object( - serialization, - 'get_byte_limiter', - new=lambda _,: byte_limiter, - ), handler_with_options(restore_concurrent_bytes=limit_bytes) as handler: - restored = handler.load(self.directory) - test_utils.assert_tree_equal(self, tree, restored) - completion_times = byte_limiter.completion_times - self.assertLen( - completion_times, - len(jax.tree.leaves(tree)), - ) - test_utils.assert_every_n_is_x_apart( - self, - completion_times, - limit_bytes // np.int32().itemsize, - sleep_time, - ) - - @parameterized.product(enable_pinned_host_transfer=(True, False)) - def test_enable_pinned_host_transfer(self, enable_pinned_host_transfer): - if multihost.is_pathways_backend(): - self.skipTest( - 'Disabled on Pathways because local variables cannot updated by' - ' reference outside remote Python.' 
- ) - true_count = 0 - false_count = 0 - - original_transfer_arrays_to_host = replica_slices.transfer_arrays_to_host - - def _transfer_arrays_to_host( - arrays, - replica_id, - use_replica_parallel, - min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel, - enable_pinned_host_transfer, - ): - nonlocal true_count, false_count - if enable_pinned_host_transfer: - true_count += 1 - else: - false_count += 1 - return original_transfer_arrays_to_host( - arrays, - replica_id, - use_replica_parallel=use_replica_parallel, - min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel=max_replicas_for_replica_parallel, - enable_pinned_host_transfer=enable_pinned_host_transfer, - ) - - with mock.patch.object( - replica_slices, - 'transfer_arrays_to_host', - new=_transfer_arrays_to_host, - ), handler_with_options( - enable_pinned_host_transfer=enable_pinned_host_transfer, - ) as handler: - handler.save(self.directory, self.pytree) - - if enable_pinned_host_transfer: - self.assertGreater(true_count, 0) - self.assertEqual(false_count, 0) - else: - self.assertEqual(true_count, 0) - self.assertGreater(false_count, 0) - - @parameterized.product( - use_ocdbt=(True, False), - pytree_metadata_options=( - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), - ), - ) - def test_write_shape_metadata_missing_for_all_types_other_than_jax_array( - self, - use_ocdbt: bool, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, - ): - checkpoint = { - 'a': 1, - 'b': np.array([2]), - 'c': 'hello', - } - expected_metadata = { - 'a': np.int64, - 'b': numpy_leaf_handler.NumpyMetadata( - shape=(1,), - dtype=checkpoint['b'].dtype, - storage_metadata=value_metadata.StorageMetadata( - chunk_shape=(1,), write_shape=None - ), - ), - 'c': str, - } - with handler_with_options( - use_ocdbt=use_ocdbt, - pytree_metadata_options=pytree_metadata_options, - array_metadata_store=ARRAY_METADATA_STORE, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, checkpoint) - - self.assertFalse((self.directory / 'array_metadatas').exists()) - restored_metadata = checkpoint_handler.metadata(self.directory) - self.assertEqual( - expected_metadata, - restored_metadata, - ) - - @parameterized.product( - use_ocdbt=(True, False), - pytree_metadata_options=( - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), - ), - ) - def test_write_shape_in_metadata_disabled( - self, - use_ocdbt: bool, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, - ): - with handler_with_options( - use_ocdbt=use_ocdbt, - pytree_metadata_options=pytree_metadata_options, - array_metadata_store=None, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - expected_tree_with_write_shapes = { - 'a': {'write_shape': None}, - 'b': {'write_shape': None}, - 'c': { - 'a': {'write_shape': None}, - 'e': {'write_shape': None}, - }, - 'x': {'write_shape': None}, - 'y': {'write_shape': None}, - } - metadata = checkpoint_handler.metadata(self.directory) - tree_with_write_shapes = jax.tree.map( - lambda m: {'write_shape': m.storage_metadata.write_shape}, metadata - ) - self.assertDictEqual( - expected_tree_with_write_shapes, tree_with_write_shapes - ) - - # TODO(b/382230550): Add test for chunk_shape != write_shape. 
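The write_shape tests above pivot on whether an `array_metadata_store` is configured. For reference, a minimal sketch of how the recorded `write_shape` can be read back from saved metadata, mirroring the assertions in those tests (`handler` and `directory` are placeholders, not part of this diff):

```python
# Sketch: inspecting the write_shape recorded in PyTree checkpoint metadata.
# Assumes `handler` is a PyTreeCheckpointHandler-like object as exercised in
# the surrounding tests, and `directory` is an existing checkpoint path.
import jax

metadata = handler.metadata(directory)
# Each leaf's StorageMetadata carries the shape actually written to disk;
# per the test above, it is None when the array metadata store is disabled.
write_shapes = jax.tree.map(
    lambda m: {'write_shape': m.storage_metadata.write_shape}, metadata
)
print(write_shapes)
```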
- @parameterized.product( - use_ocdbt=(True, False), - pytree_metadata_options=( - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), - ), - ) - def test_write_shape_in_metadata( - self, - use_ocdbt: bool, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, - ): - with handler_with_options( - use_ocdbt=use_ocdbt, pytree_metadata_options=pytree_metadata_options - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, self.pytree) - - expected_tree_with_write_shapes = { - 'a': { - 'write_shape': test_utils.get_expected_chunk_shape( - self.pytree['a'] - ) - }, - 'b': {'write_shape': (2,)}, - 'c': { - 'a': {'write_shape': (1, 1)}, - 'e': {'write_shape': (2, 1)}, - }, - 'x': {'write_shape': ()}, - 'y': {'write_shape': ()}, - } - metadata = checkpoint_handler.metadata(self.directory) - tree_with_write_shapes = jax.tree.map( - lambda m: {'write_shape': m.storage_metadata.write_shape}, metadata - ) - self.assertDictEqual( - expected_tree_with_write_shapes, tree_with_write_shapes - ) - - @parameterized.product(use_ocdbt=(True, False)) - def test_array_metadata_disabled(self, use_ocdbt: bool): - with handler_with_options( - use_ocdbt=use_ocdbt, array_metadata_store=None - ) as checkpoint_handler: - pytree, abstract_pytree = create_mixed_format_pytree() - - checkpoint_handler.save(self.directory, pytree) - - self.validate_save( - self.directory, - abstract_pytree, - pytree, - checkpoint_handler, - ) - - self.assertFalse((self.directory / 'array_metadatas').exists()) - - @parameterized.product(use_ocdbt=(True, False)) - def test_array_metadata(self, use_ocdbt: bool): - with handler_with_options(use_ocdbt=use_ocdbt) as checkpoint_handler: - - checkpoint_handler.save(self.directory, self.pytree) - - self.validate_save( - self.directory, - self.abstract_pytree, - self.pytree, - checkpoint_handler, - ) - - self.assertTrue((self.directory / 'array_metadatas').exists()) - if multihost.is_primary_host(0): - array_metadatas = asyncio.run(ARRAY_METADATA_STORE.read(self.directory)) - self.assertIsInstance(array_metadatas, dict) - per_process_metadatas = [ - array_metadata.SerializedArrayMetadata( - param_name='a', - write_shape=test_utils.get_expected_chunk_shape( - self.pytree['a'] - ), - chunk_shape=test_utils.get_expected_chunk_shape( - self.pytree['a'] - ), - ), - array_metadata.SerializedArrayMetadata( - param_name='b', - write_shape=(2,), - chunk_shape=(2,), - ), - array_metadata.SerializedArrayMetadata( - param_name='c.a', - write_shape=(1, 1), - chunk_shape=(1, 1), - ), - array_metadata.SerializedArrayMetadata( - param_name='c.e', - write_shape=(2, 1), - chunk_shape=(2, 1), - ), - array_metadata.SerializedArrayMetadata( - param_name='x', - write_shape=(), - chunk_shape=(), - ), - array_metadata.SerializedArrayMetadata( - param_name='y', - write_shape=(), - chunk_shape=(), - ), - ] - processes = range(multihost.process_count()) - expected_array_metadatas = { - idx: per_process_metadatas for idx in processes - } - self.assertSameElements( - expected_array_metadatas.keys(), array_metadatas.keys() - ) - for process_index in expected_array_metadatas: - self.assertEqual( # pylint: disable=g-generic-assert - sorted( - expected_array_metadatas[process_index], - key=lambda x: x.param_name, - ), - sorted( - array_metadatas[process_index], key=lambda x: x.param_name - ), - ) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_with_missing_array_metadata_file(self, use_ocdbt: bool): - if 
multihost.process_index() != 0: # only test on primary host - self.skipTest('Test only for primary host to avoid barrier timeout.') - - class PathResolverReturningNoMetadataFiles( - array_metadata_store_lib.PathResolver - ): - - async def get_read_file_paths( - self, checkpoint_dir: epath.Path, process_index: int | None = None - ) -> Iterator[epath.Path] | epath.Path | None: - return None - - with handler_with_options( - use_ocdbt=use_ocdbt, - array_metadata_store=array_metadata_store_lib.Store( - path_resolver=PathResolverReturningNoMetadataFiles() - ), - ) as checkpoint_handler: - with self.assertRaisesRegex( - ValueError, 'No ArrayMetadata found for process_index' - ): - checkpoint_handler.save(self.directory, self.pytree) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_with_missing_array_metadata_for_params(self, use_ocdbt: bool): - if multihost.process_index() != 0: # only test on primary host - self.skipTest('Test only for primary host to avoid barrier timeout.') - - class MissingArrayMetadataSerializer( - array_metadata_store_lib.Serializer - ): - - def deserialize( - self, serialized: str - ) -> List[array_metadata.SerializedArrayMetadata]: - true_data = super().deserialize(serialized) - return [true_data.pop(0)] # Delete the rest and return partial data. - - with handler_with_options( - use_ocdbt=use_ocdbt, - array_metadata_store=array_metadata_store_lib.Store( - serializer=MissingArrayMetadataSerializer() - ), - ) as checkpoint_handler: - with self.assertRaisesRegex( - ValueError, 'No ArrayMetadata found for param_info' - ): - checkpoint_handler.save(self.directory, self.pytree) - - @parameterized.parameters((True,), (False,)) - def test_zero_size_array(self, use_jax_array: bool): - arr = np.ones(shape=(0,)) - mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('x',)) - pspec = jax.sharding.PartitionSpec() - if use_jax_array: - arr = test_utils.create_sharded_array(arr, mesh, pspec) - pytree = [arr] - with self.assertRaisesRegex(ValueError, 'zero size'): - self.handler.save(self.directory, pytree) - - @parameterized.product(use_ocdbt=(True, False)) - def test_save_restore_random_keys(self, use_ocdbt: bool): - """Test saving and restoring random keys within a pytree.""" - - # TODO(b/393160483) investigate Pathways remote Python support for - # random.keys. - if multihost.is_pathways_backend(): - self.skipTest( - 'Disabled on Pathways because random keys are not supported by' - ' remote Python.' - ) - - mesh = jax.sharding.Mesh(jax.devices(), ('x',)) - sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - pytree = { - 'keys': { - 'kone': jax.random.key(jnp.array(0, device=sharding)), - 'impl_key': { - 'rbg': jax.random.key( - jnp.array(1, device=sharding), impl='rbg' - ), - 'unsafe_rbg': jax.random.key( - jnp.array(2, device=sharding), impl='unsafe_rbg' - ), - }, - 'split_keys': jax.random.split( - jax.random.key(jnp.array(123, device=sharding)), num=10 - ), - }, - 'arrays': self.pytree, - } - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as save_handler: - save_handler.save(self.directory, pytree) - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as load_handler: - restored = load_handler.load(self.directory) - test_utils.assert_tree_equal(self, pytree, restored) - - def test_pinned_host_loading(self): - if multihost.is_pathways_backend(): - # TODO(b/404915487): Reenable when possible. 
- self.skipTest('Disabled due to b/404915487.') - - mesh = jax.sharding.Mesh( - np.asarray(jax.devices()).reshape((1, len(jax.devices()))), ('x', 'y') - ) - sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec('x', 'y') - ).with_memory_kind('pinned_host') - - pytree = dict(arr=jnp.ones((1024, 512), device=sharding)) - self.handler.save(self.directory, pytree) - - abstract_pytree = dict( - arr=jax.ShapeDtypeStruct( - pytree['arr'].shape, pytree['arr'].dtype, sharding=sharding - ) - ) - restored = self.handler.load(self.directory, abstract_pytree) - expected = dict(arr=jax.device_put(np.ones((1024, 512)), sharding)) - test_utils.assert_tree_equal(self, expected, restored) - - @parameterized.product( - use_ocdbt=(True, False), - reference_item=( - { - 'a': 0, - 'b': 0, - 'c': { - 'e': 0, - }, - }, - { - 'a': 0, - 'c': { - 'a': 0, - 'e': 0, - }, - }, - { - 'a': 0, - 'b': 0, - }, - ), - ) - def test_restore_item_has_missing_leaves( - self, use_ocdbt: bool, reference_item: dict[str, Any] - ): - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as handler: - handler.save(self.directory, self.pytree) - - with self.assertRaisesRegex( - ValueError, 'User-provided restore item and on-disk value' - ): - handler.load(self.directory, reference_item) - - def test_partial_restore_with_placeholder_simple(self): - original_item = { - 'a': np.arange(8), - 'b': np.arange(8), - 'c': { - 'a': np.arange(8), - 'e': np.arange(8), - }, - } - reference_item = jax.tree.map(as_abstract_type, original_item) - reference_item['b'] = PLACEHOLDER - reference_item['c']['e'] = PLACEHOLDER - expected = { - 'a': original_item['a'], - 'b': PLACEHOLDER, - 'c': { - 'a': original_item['c']['a'], - 'e': PLACEHOLDER, - }, - } - - simple_dir = epath.Path( - self.create_tempdir(name='simple_placeholder_dir').full_path - ) - - with handler_with_options() as handler: - handler.save(simple_dir, original_item) - restored = handler.load(simple_dir, reference_item) - test_utils.assert_tree_equal(self, expected, restored) - - @parameterized.product(use_ocdbt=(True, False)) - def test_partial_restore_with_placeholder(self, use_ocdbt: bool): - """Test saving and restoring placeholder.""" - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as save_handler: - save_handler.save(self.directory, self.pytree) - - with self.subTest('success'): - reference_item = self.abstract_pytree.copy() - reference_item['b'] = PLACEHOLDER - reference_item['c']['e'] = PLACEHOLDER - - expected = self.pytree.copy() - expected['b'] = PLACEHOLDER - expected['c']['e'] = PLACEHOLDER - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as restore_handler: - restored = restore_handler.load(self.directory, reference_item) - test_utils.assert_tree_equal(self, expected, restored) - - with self.subTest('missing_leaf'): - reference_item = self.abstract_pytree.copy() - reference_item['b'] = PLACEHOLDER - reference_item['c']['e'] = PLACEHOLDER - del reference_item['c']['a'] - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as restore_handler: - with self.assertRaisesRegex( - ValueError, 'User-provided restore item and on-disk value' - ): - restore_handler.load(self.directory, reference_item) - - with self.subTest('non_leaf_placeholder'): - reference_item = self.abstract_pytree.copy() - reference_item['c'] = PLACEHOLDER - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as restore_handler: - with self.assertRaisesRegex( - ValueError, 'User-provided restore item and on-disk value' - ): - restore_handler.load(self.directory, 
reference_item) - - @parameterized.product(use_ocdbt=(True, False)) - def test_partial_restore_with_omission(self, use_ocdbt: bool): - """Basic save and restore test.""" - directory = self.directory / 'partial_restore' - - with handler_with_options( - use_ocdbt=use_ocdbt, - ) as save_handler: - save_handler.save(directory, self.pytree) - - with self.subTest('success'): - with handler_with_options( - use_ocdbt=use_ocdbt, - partial_load=True, - ) as restore_handler: - # Create a new pytree structure with the same leaves. - # Leaves (ShapeDtypeStruct) are immutable and can be shared. - reference_item = jax.tree.map(lambda x: x, self.abstract_pytree) - # Omit 'b', 'c.e', and 'x' from the reference item. - del reference_item['b'] - del reference_item['c']['e'] - del reference_item['x'] - expected = { - 'a': self.pytree['a'], - 'c': { - 'a': self.pytree['c']['a'], - }, - 'y': self.pytree['y'], - } - restored = restore_handler.load(directory, reference_item) - test_utils.assert_tree_equal(self, expected, restored) - - with self.subTest('extra_leaf'): - with handler_with_options( - use_ocdbt=use_ocdbt, - partial_load=True, - ) as restore_handler: - # Create a new pytree structure with the same leaves. - # Leaves (ShapeDtypeStruct) are immutable and can be shared. - reference_item = jax.tree.map(lambda x: x, self.abstract_pytree) - del reference_item['b'] - del reference_item['c']['e'] - del reference_item['x'] - # Add an extra leaf to the reference item. - reference_item['z'] = jax.ShapeDtypeStruct([0], np.int64) - with self.assertRaisesRegex( - ValueError, - r"Missing 1 keys in structure path \(\), including: \['z'\]", - ): - restore_handler.load(directory, reference_item) - - @parameterized.product(use_zarr3=(True, False), use_ocdbt=(True, False)) - def test_custom_leaf_handler(self, use_zarr3: bool, use_ocdbt: bool): - - pytree = { - 'point1': Point(1, 2), - 'point2': Point(3, 4), - 'nested': { - 'point3': Point(5, 6), - 'point4': Point(7, 8), - }, - 'string_leaf': 'string_leaf', - 'number': 123, - 'pytree': self.pytree, - } - - array_metadata_store = ARRAY_METADATA_STORE - - leaf_handler_registry = registry.StandardLeafHandlerRegistry() - leaf_handler_registry.add(Point, AbstractPoint, PointLeafHandler) - - def _as_abstract_type(x): - if isinstance(x, Point): - return AbstractPoint - return as_abstract_type(x) - - with handler_with_options( - use_ocdbt=use_ocdbt, - leaf_handler_registry=leaf_handler_registry, - array_metadata_store=array_metadata_store, - use_zarr3=use_zarr3, - ) as checkpoint_handler: - checkpoint_handler.save(self.directory, pytree) - abstract_pytree = jax.tree.map(_as_abstract_type, pytree) - restored = checkpoint_handler.load(self.directory, abstract_pytree) - - test_utils.assert_tree_equal(self, pytree, restored) - - self.validate_metadata( - expected_reference_metadata_tree=pytree, - actual_metadata=checkpoint_handler.metadata(self.directory), - pytree_metadata_options=self.pytree_metadata_options, - array_metadata_store=array_metadata_store, - ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py index 778272de2..2f1f74740 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py @@ -59,6 +59,7 @@ from absl import logging from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +from orbax.checkpoint.experimental.v1._src.path 
import format_utils
 
 CheckpointableHandler = handler_types.CheckpointableHandler
 RegistryEntry = tuple[Type[CheckpointableHandler], str | None]
 
@@ -116,6 +117,10 @@ class NoEntryError(KeyError):
   """Raised when no entry exists in the registry."""
 
 
+class NotHandleableError(ValueError):
+  """Raised when a checkpointable is not handleable by a handler."""
+
+
 class _DefaultCheckpointableHandlerRegistry(CheckpointableHandlerRegistry):
   """Default implementation of `CheckpointableHandlerRegistry`."""
 
@@ -346,7 +351,7 @@ def _construct_handler_instance(
 
 def _get_possible_handlers(
     registry: CheckpointableHandlerRegistry,
-    is_handleable_fn: Callable[[CheckpointableHandler, Any], bool],
+    is_handleable: Callable[[CheckpointableHandler, Any], bool],
     checkpointable: Any | None,
     name: str,
 ) -> Sequence[CheckpointableHandler]:
@@ -370,7 +375,7 @@ def _get_possible_handlers(
       handler
       for handler, checkpointable_name in registry_entries
       if checkpointable_name is None
-      and is_handleable_fn(handler, checkpointable)
+      and is_handleable(handler, checkpointable)
   ]
   if not possible_handlers:
     available_handlers = [
@@ -389,6 +394,29 @@ def _get_possible_handlers(
   return possible_handlers
 
 
+def _maybe_raise_not_handleable_error(
+    handler: CheckpointableHandler,
+    is_handleable: Callable[[CheckpointableHandler, Any], bool],
+    checkpointable: Any,
+    name: str,
+) -> None:
+  """Raises an error if the handler cannot handle the named checkpointable."""
+  if not is_handleable(handler, checkpointable):
+    error_msg = (
+        f'Handler {type(handler)}, explicitly registered for {name}, cannot'
+        ' handle the provided checkpointable.'
+    )
+    if name == format_utils.PYTREE_CHECKPOINTABLE_KEY:
+      error_msg += (
+          ' Usage of the name'
+          f' {format_utils.PYTREE_CHECKPOINTABLE_KEY} indicates that you'
+          ' attempted to save using `save/load_pytree`. Please ensure the'
+          ' provided checkpointable is a nested PyTree, and not a leaf node.'
+      )
+    error_msg += f' Received checkpointable: {checkpointable}.'
+    raise NotHandleableError(error_msg)
+
+
 def resolve_handler_for_save(
     registry: CheckpointableHandlerRegistry,
     checkpointable: Any,
@@ -411,22 +439,29 @@ def resolve_handler_for_save(
 
   Raises:
     NoEntryError: If no compatible `CheckpointableHandler` can be found.
+    NotHandleableError: If the resolved handler cannot handle the
+      checkpointable and the handler was explicitly registered for the
+      checkpointable name.
 
   Returns:
     A CheckpointableHandler instance.
   """
+
+  def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool:
+    return handler.is_handleable(ckpt)
+
   # If explicitly registered, use that first.
   if registry.has(name):
-    return _construct_handler_instance(name, registry.get(name))
+    handler = _construct_handler_instance(name, registry.get(name))
+    _maybe_raise_not_handleable_error(
+        handler, is_handleable, checkpointable, name
+    )
+    return handler
 
   if checkpointable is None:
     raise ValueError('checkpointable must not be None for saving.')
 
-  def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
-    return handler.is_handleable(ckpt)
-
   possible_handlers = _get_possible_handlers(
-      registry, is_handleable_fn, checkpointable, name
+      registry, is_handleable, checkpointable, name
   )
 
   # Prefer the first handler in the absence of any other information.
@@ -462,18 +497,28 @@ def resolve_handler_for_load(
     name: The name of the checkpointable.
     handler_typestr: A CheckpointableHandler typestr to guide resolution.
 
+  Raises:
+    NotHandleableError: If the resolved handler cannot handle the abstract
+      checkpointable and the handler was explicitly registered for the
+      checkpointable name.
+
   Returns:
     A CheckpointableHandler instance.
   """
-  # If explicitly registered, use that first.
-  if registry.has(name):
-    return _construct_handler_instance(name, registry.get(name))
-
   def is_handleable_fn(
       handler: CheckpointableHandler, ckpt: Any
   ) -> bool | None:
     return handler.is_abstract_handleable(ckpt)
 
+  # If explicitly registered, use that first.
+  if registry.has(name):
+    handler = _construct_handler_instance(name, registry.get(name))
+    if abstract_checkpointable is not None:
+      _maybe_raise_not_handleable_error(
+          handler, is_handleable_fn, abstract_checkpointable, name
+      )
+    return handler
+
   possible_handlers = _get_possible_handlers(
       registry, is_handleable_fn, abstract_checkpointable, name
   )
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py
index 040a4662a..61b403c6e 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py
@@ -57,11 +57,10 @@ def _create_value_metadata(self, value):
         dtype=value.dtype,
         storage_metadata=storage_metadata,
     )
-  elif isinstance(value, (int, float)):
-    dtype = np.float64
-    if isinstance(value, int):
-      dtype = np.int64
-    return dtype
+  elif isinstance(value, (int, np.integer)):
+    return 0
+  elif isinstance(value, (float, np.floating)):
+    return 0.0
   else:
     raise TypeError(f'Unsupported type: {type(value)}')
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py
index 1fc554648..89650adb8 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py
@@ -272,6 +272,12 @@ def save_checkpointables_impl(
     partial_save: bool = False,
 ) -> async_types.AsyncResponse[None]:
   """See caller docstrings."""
+  if not isinstance(checkpointables, dict):
+    raise ValueError(
+        f'`checkpointables` must be a dict, but got {type(checkpointables)}'
+    )
+  if not checkpointables:
+    raise ValueError('`checkpointables` must be a non-empty dict.')
   _maybe_apply_nest_asyncio()
   context = context_lib.get_context()
   path = epath.Path(path)
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py
index 1c9e68749..942488e32 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py
@@ -20,6 +20,7 @@
 
 import asyncio
 import dataclasses
+import typing
 from typing import Awaitable, Sequence, cast
 
 from absl import logging
@@ -31,13 +32,16 @@
 from orbax.checkpoint._src.metadata import value as value_metadata
 from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
 from orbax.checkpoint.experimental.v1._src.context import context as context_lib
+from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
 from orbax.checkpoint.experimental.v1._src.serialization import types
 
-ArraySerializationParam = types.SerializationParam[jax.Array]
-ArrayDeserializationParam =
types.DeserializationParam["AbstractShardedArray"] Shape = arrays_types_v0.Shape AbstractShardedArray = types.AbstractShardedArray +ArraySerializationParam = types.SerializationParam[jax.Array] +ArrayDeserializationParam = types.DeserializationParam[ + AbstractShardedArray +] @dataclasses.dataclass @@ -148,12 +152,15 @@ def _create_v0_restore_paraminfo( ) -> type_handlers_v0.ParamInfo: """Creates a V0 ParamInfo from V1 params and contexts for loading.""" - loading_options = context.array_options.Loading + loading_options = context.array_options.loading if isinstance(param.value, ArrayMetadata): # the write_shape is populated for metadata() calls. v = cast(ArrayMetadata, param.value) - write_shape = v.storage_metadata.write_shape + if v.storage_metadata is not None: + write_shape = v.storage_metadata.write_shape + else: + write_shape = None else: write_shape = None @@ -176,23 +183,20 @@ def _create_v0_restorearg( context: context_lib.Context, ) -> type_handlers_v0.ArrayRestoreArgs: """Creates a V0 ArrayRestoreArgs from V1 params.""" - - if param.value is None: + value = param.value + if value is None or isinstance(value, type): return type_handlers_v0.ArrayRestoreArgs(restore_type=jax.Array) - else: - v = param.value - if not isinstance(v, (jax.Array, jax.ShapeDtypeStruct, ArrayMetadata)): - raise ValueError( - "ArrayDeserializationParam.value is an unsupported type:" - f" {type(v)} for param.name: {param.name}" - ) + elif protocol_utils.is_subclass_protocol(value, AbstractShardedArray): + value = typing.cast(AbstractShardedArray, value) return type_handlers_v0.ArrayRestoreArgs( restore_type=jax.Array, - dtype=v.dtype, - sharding=v.sharding, - shape=v.shape, + dtype=value.dtype, + sharding=value.sharding, + shape=value.shape, strict=not context.array_options.loading.enable_padding_and_truncation, ) + else: + raise TypeError(f'Unrecognized abstract value type: {type(value)}') async def _async_futures(commit_futures: Sequence[future.Future]): @@ -212,7 +216,7 @@ def __init__( self._context, ) - logging.vlog(1, "ArrayLeafHandler created.") + logging.vlog(1, 'ArrayLeafHandler created.') async def serialize( self, @@ -302,7 +306,7 @@ async def _convert_to_array_metadata() -> Sequence[ArrayMetadata]: ) ret.append(array_metadata) - logging.vlog(1, "array_metadata: %r", array_metadata) + logging.vlog(1, 'array_metadata: %r', array_metadata) return ret diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py index 23df85928..5f60ef0df 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py @@ -115,8 +115,11 @@ def _construct_deserialization_param( ) -> types.DeserializationParam[ types.AbstractShardedArray | types.AbstractArray + | Type[types.AbstractArray] | types.AbstractScalar + | Type[types.AbstractScalar] | types.AbstractString + | Type[types.AbstractString] | None ]: """Constructs a DeserializationParam from a ParamInfo and RestoreArg.""" diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/format_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/format_utils.py new file mode 100644 index 000000000..ee2cb3eb2 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/format_utils.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Orbax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for PyTree serialization format.""" + +from etils import epath +from orbax.checkpoint._src.path import async_path +from orbax.checkpoint.experimental.v1._src.path import types as path_types + + +_OCDBT_MANIFEST_FILE = 'manifest.ocdbt' + + +async def is_ocdbt_checkpoint(path: path_types.PathLike) -> bool: + """Determines whether a checkpoint uses OCDBT format.""" + path = epath.Path(path) + return await async_path.exists(path / _OCDBT_MANIFEST_FILE) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index 7be66a956..b9ed98e71 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -20,6 +20,7 @@ import asyncio import dataclasses +import typing from typing import Awaitable, Sequence from absl import logging @@ -33,7 +34,9 @@ NumpySerializationParam = types.SerializationParam[np.ndarray] -NumpyDeserializationParam = types.DeserializationParam[types.AbstractArray] +NumpyDeserializationParam = types.DeserializationParam[ + types.AbstractArray +] Shape = arrays_types.Shape AbstractArray = types.AbstractArray @@ -87,7 +90,7 @@ def _create_v0_saving_paraminfo( use_zarr3=saving_options.use_zarr3, ocdbt_target_data_file_size=saving_options.ocdbt_target_data_file_size, ts_context=serialization_context.ts_context, - value_typestr="np.ndarray", + value_typestr='np.ndarray', ) @@ -121,7 +124,7 @@ def _create_v0_restore_paraminfo( ) -> type_handlers_v0.ParamInfo: """Creates a V0 ParamInfo from V1 params and contexts for loading.""" - loading_options = context.array_options.Loading + loading_options = context.array_options.loading return type_handlers_v0.ParamInfo( name=param.name, @@ -141,26 +144,17 @@ def _create_v0_restorearg( ) -> type_handlers_v0.RestoreArgs: """Creates a V0 RestoreArgs from V1 params.""" - if param.value is None: - return type_handlers_v0.RestoreArgs(restore_type=np.ndarray) + value = param.value + if value is None or isinstance(value, type): + return type_handlers_v0.RestoreArgs( + restore_type=np.ndarray, + ) else: - v = param.value - if not isinstance( - v, - ( - np.ndarray, - NumpyShapeDtype, - NumpyMetadata, - ), - ): - raise ValueError( - f"NumpyDeserializationParam.value is an unsupported type: {type(v)}" - ) - - logging.vlog(1, "name: %s, v.dtype: %s", param.name, v.dtype) + value = typing.cast(types.AbstractArray, value) + logging.vlog(1, 'name: %s, v.dtype: %s', param.name, value.dtype) return type_handlers_v0.RestoreArgs( restore_type=np.ndarray, - dtype=v.dtype, + dtype=value.dtype, ) @@ -179,7 +173,7 @@ def __init__( self._context = context_lib.get_context(context) self._handler_impl = _create_v0_numpy_handler() - logging.vlog(1, "NumpyLeafHandler created.") + logging.vlog(1, 'NumpyLeafHandler created.') async def serialize( self, @@ 
-212,7 +206,7 @@ async def serialize( async def deserialize( self, - params: Sequence[types.DeserializationParam[AbstractArray]], + params: Sequence[NumpyDeserializationParam], deserialization_context: types.DeserializationContext, ) -> Awaitable[Sequence[np.ndarray]]: """Returns sequence of np.ndarrays from a stored checkpointable location. @@ -224,7 +218,6 @@ async def deserialize( Returns: The deserialized sequence of nd.ndarays as leaves. """ - # validate all parameters paraminfos = [ _create_v0_restore_paraminfo(p, self._context, deserialization_context) @@ -268,7 +261,7 @@ async def _convert_to_numpy_metadata() -> Sequence[NumpyMetadata]: ) ret.append(numpy_metadata) - logging.vlog(1, "numpy_metadata: %r", numpy_metadata) + logging.vlog(1, 'numpy_metadata: %r', numpy_metadata) return ret diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py index 2b1648321..c3be2d9b2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py @@ -28,6 +28,14 @@ import typing_extensions +class AlreadyRegisteredTypeError(ValueError): + """Raised when a leaf is already registered.""" + + +class UnregisteredTypeError(ValueError): + """Raised when a leaf is not registered.""" + + # The standard type and abstract type to handler mapping. # The type to abstract type pairs are well defined standard and users should # rarely need to override the pair. @@ -90,7 +98,7 @@ def get( self, leaf_type: Type[types.Leaf] ) -> Type[types.LeafHandler[types.Leaf, Any]]: if (handler_type := self._try_get(leaf_type)) is None: - raise ValueError( + raise UnregisteredTypeError( f'Unknown Leaf type: "{leaf_type}". Must register it with' ' LeafHandlerRegistry.' ) @@ -122,7 +130,7 @@ def get_abstract( abstract_type: Type[types.AbstractLeaf], ) -> Type[types.LeafHandler[Any, types.AbstractLeaf]]: if (handler_type := self._try_get_abstract(abstract_type)) is None: - raise ValueError( + raise UnregisteredTypeError( f'Unknown AbstractLeaf type: "{abstract_type}". Must register it with' ' LeafHandlerRegistry.' ) @@ -156,7 +164,7 @@ def add( current_abstract_handle_type = self._try_get_abstract(abstract_type) if not override and (current_handler_type or current_abstract_handle_type): - raise ValueError( + raise AlreadyRegisteredTypeError( f'Leaf_type[{leaf_type}] or abstract_type[{abstract_type}] has' f' already registered, current_handler: {current_handler_type}, ' f'current_abstract_handle_type: {current_abstract_handle_type}' @@ -177,7 +185,7 @@ def add( current_abstract_handle_type and current_handler_type != current_abstract_handle_type ): - raise ValueError( + raise AlreadyRegisteredTypeError( f'Abstract_type[{abstract_type}] has already registered with a' ' different type.' 
) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index 141a5632f..810506fb9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -19,7 +19,7 @@ """ import asyncio -from typing import Awaitable, Sequence +from typing import Awaitable, Sequence, Type from absl import logging import jax.numpy as jnp @@ -32,7 +32,9 @@ Scalar = types.Scalar AbstractScalar = types.AbstractScalar ScalarSerializationParam = types.SerializationParam[Scalar] -ScalarDeserializationParam = types.DeserializationParam[AbstractScalar] +ScalarDeserializationParam = types.DeserializationParam[ + AbstractScalar +] def _create_v0_scalar_handler() -> type_handlers_v0.ScalarHandler: @@ -85,13 +87,15 @@ def _create_v0_savearg( def _create_v0_restore_paraminfo( - param: types.DeserializationParam[None | AbstractScalar], + param: types.DeserializationParam[ + AbstractScalar | Type[AbstractScalar] | None + ], context: context_lib.Context, deserialization_context: types.DeserializationContext, ) -> type_handlers_v0.ParamInfo: """Creates a V0 ParamInfo from V1 params and contexts for loading.""" - loading_options = context.array_options.Loading + loading_options = context.array_options.loading return type_handlers_v0.ParamInfo( name=param.name, @@ -122,6 +126,16 @@ def _create_v0_restorearg( ) +def _np_dtype_to_python_type(dtype): + """Converts dtype by checking its fundamental type.""" + if np.issubdtype(dtype, np.integer): + return int + elif np.issubdtype(dtype, np.floating): + return float + else: + raise TypeError(f"Unsupported dtype: {dtype}.") + + async def _async_futures(commit_futures: Sequence[future.Future]): await asyncio.gather(*[asyncio.to_thread(f.result) for f in commit_futures]) @@ -155,11 +169,14 @@ async def serialize( operation. """ values = [p.value for p in params] + logging.info("values: %s", values) paraminfos = [ _create_v0_saving_paraminfo(p, self._context, serialization_context) for p in params ] + logging.info("paraminfos: %s", paraminfos) saveargs = [_create_v0_savearg(p, self._context) for p in params] + logging.info("saveargs: %s", saveargs) commit_futures = await self._handler_impl.serialize( values, paraminfos, saveargs @@ -170,7 +187,7 @@ async def serialize( async def deserialize( self, - params: Sequence[types.DeserializationParam[AbstractScalar]], + params: Sequence[ScalarDeserializationParam], deserialization_context: types.DeserializationContext, ) -> Awaitable[Sequence[Scalar]]: """Returns sequence of Scalar values from a stored checkpointable location. 
@@ -225,9 +242,11 @@ def _get_type(meta: type_handlers_v0.ScalarMetadata): raise ValueError("dtype is None") if isinstance(meta.dtype, (np.dtype | jnp.dtype)): - return meta.dtype.type + t = _np_dtype_to_python_type(meta.dtype) else: - return meta.dtype + t = meta.dtype + + return t(0) ret = [_get_type(meta) for meta in v0_metadatas] diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py index 81c8c8aab..b20cc5469 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py @@ -19,7 +19,7 @@ """ import asyncio -from typing import Awaitable, Sequence +from typing import Awaitable, Sequence, Type from absl import logging from orbax.checkpoint._src.futures import future @@ -29,7 +29,9 @@ AbstractString = types.AbstractString StringSerializationParam = types.SerializationParam[str] -StringDeserializationParam = types.DeserializationParam[AbstractString] +StringDeserializationParam = types.DeserializationParam[ + AbstractString +] def _create_v0_saving_paraminfo( @@ -55,13 +57,15 @@ def _create_v0_saving_paraminfo( def _create_v0_restore_paraminfo( - param: types.DeserializationParam[None | AbstractString], + param: types.DeserializationParam[ + AbstractString | Type[AbstractString] | None + ], context: context_lib.Context, deserialization_context: types.DeserializationContext, ) -> type_handlers_v0.ParamInfo: """Creates a V0 ParamInfo from V1 params and contexts for loading.""" - loading_options = context.array_options.Loading + loading_options = context.array_options.loading return type_handlers_v0.ParamInfo( name=param.name, @@ -130,7 +134,7 @@ async def serialize( async def deserialize( self, - params: Sequence[types.DeserializationParam[AbstractString]], + params: Sequence[StringDeserializationParam], deserialization_context: types.DeserializationContext, ) -> Awaitable[Sequence[str]]: """Returns sequence of String values from a stored checkpointable location. @@ -180,6 +184,6 @@ async def metadata( async def _get_metadata() -> Sequence[AbstractString]: v0_metadatas = await self._handler_impl.metadata(paraminfos) - return [str] * len(v0_metadatas) + return ["string"] * len(v0_metadatas) return await _get_metadata() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py index 586c0ab3a..6831774f1 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py @@ -39,9 +39,8 @@ Scalar = int | float | np.number # Optional type hint for a scalar leaf handler. If provided, the restored scalar # will be cast to this type. Only casting to int or float is supported. -AbstractScalar = Type[Scalar] | Scalar - -AbstractString = Type[str] +AbstractScalar = Scalar +AbstractString = str class AbstractArray(Protocol): @@ -110,7 +109,7 @@ class SerializationContext: @dataclasses.dataclass class DeserializationParam(Generic[AbstractLeaf]): keypath: tree_types.PyTreeKeyPath - value: AbstractLeaf | None = None + value: AbstractLeaf | Type[AbstractLeaf] | None = None @property def name(self) -> str: @@ -171,7 +170,11 @@ async def deserialize( confirm the completion of this data transfer. Args: - params: sequence of DeserializationParam per leaf. 
+      params: sequence of DeserializationParam per leaf. The Param contains a
+        value corresponding to the `AbstractLeaf` type; `Type[AbstractLeaf]`
+        is always valid as well. E.g. if the `AbstractLeaf` is `AbstractFoo`,
+        it is always valid to pass `AbstractFoo()` or `AbstractFoo`. Passing
+        the bare type (or `None`) indicates that checkpoint metadata should
+        be used to restore the leaf.
       deserialization_context: DeserializationContext for the leaf handler.
 
     Returns:
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py
index 24ebc9a52..a333746b1 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py
@@ -19,7 +19,7 @@
 import asyncio
 import dataclasses
 import json
-from typing import Any, Awaitable, Generic, Type, TypeVar
+from typing import Any, Awaitable, Generic, Protocol, Type, TypeVar
 
 import aiofiles
 from etils import epath
@@ -46,33 +46,63 @@ async def _run_awaitable(awaitable: Awaitable[Any]) -> Any:
   return await awaitable
 
 
-class _TestHandler(Generic[T, AbstractT]):
+class TestHandler(Protocol[T, AbstractT]):
   """This class facilitates testing of CheckpointableHandlers independently.
 
   Use `create_test_handler`.
   """
 
+  def save(self, directory: Path, checkpointable: T) -> None:
+    ...
+
+  def save_async(self, directory: Path, checkpointable: T) -> Awaitable[None]:
+    ...
+
+  def load(
+      self, path: Path, abstract_checkpointable: AbstractT | None = None
+  ) -> T:
+    ...
+
+  def load_async(
+      self, path: Path, abstract_checkpointable: AbstractT | None = None
+  ) -> Awaitable[T]:
+    ...
+
+  def metadata(self, path: Path) -> AbstractT:
+    ...
+
+  def is_handleable(self, checkpointable: Any) -> bool:
+    ...
+
+  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
+    ...
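The `DeserializationParam.value` widening above (`AbstractLeaf | Type[AbstractLeaf] | None`) admits three spellings of an abstract leaf. A small illustrative sketch, consistent with the docs change at the end of this patch; the tree keys and the commented `load_pytree` call are hypothetical usage:

```python
# Sketch: the three accepted forms of an abstract leaf when loading.
import jax
import numpy as np

abstract_pytree = {
    # Fully specified: strongest validation, fewest metadata reads.
    'w': jax.ShapeDtypeStruct(shape=(8,), dtype=np.float32),
    # Bare type: properties come from checkpoint metadata, but the leaf is
    # guaranteed to load as this type.
    'step': int,
    # None: both type and properties come from checkpoint metadata.
    'scale': None,
}
# restored = ocp.load_pytree(path, abstract_pytree)
```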
+
+
+class _TestHandler(Generic[T, AbstractT]):
+
   def __init__(
       self, handler_class: type[CheckpointableHandler[T, AbstractT]], **kwargs
   ):
     self._handler: CheckpointableHandler[T, AbstractT] = handler_class(**kwargs)
 
-  def save(self, directory: Path, checkpointable: T):
+  def save(self, directory: Path, checkpointable: T) -> None:
     path = path_test_utils.PathAwaitingCreationWrapper(directory)
     awaitable = asyncio.run(self._handler.save(path, checkpointable))
-    return asyncio.run(_run_awaitable(awaitable))
+    asyncio.run(_run_awaitable(awaitable))
 
-  def save_async(self, directory: Path, checkpointable: T):
+  def save_async(self, directory: Path, checkpointable: T) -> Awaitable[None]:
     path = path_test_utils.PathAwaitingCreationWrapper(directory)
     return asyncio.run(self._handler.save(path, checkpointable))
 
-  def load(self, path: Path, abstract_checkpointable: AbstractT | None = None):
+  def load(
+      self, path: Path, abstract_checkpointable: AbstractT | None = None
+  ) -> T:
     awaitable = self.load_async(path, abstract_checkpointable)
     return asyncio.run(_run_awaitable(awaitable))
 
   def load_async(
       self, path: Path, abstract_checkpointable: AbstractT | None = None
-  ):
+  ) -> Awaitable[T]:
     return asyncio.run(self._handler.load(path, abstract_checkpointable))
 
   def metadata(self, path: Path) -> AbstractT:
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
index 789547acf..6235cdff4 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
@@ -171,7 +171,7 @@ def mock_serialize(*args, **kwargs):
       (optax.EmptyState(),),
   )
   def test_empty_tree(self, tree):
-    with self.assertRaisesRegex(ValueError, 'empty'):
+    with self.assertRaises(registration.NotHandleableError):
       ocp.save_pytree(self.directory, tree)
 
   # Note the ommission of jax.Array, since this is covered in
@@ -191,6 +191,42 @@ def test_standard_leaf_types(self, value):
     else:
       self.assertEqual(loaded['k'], value)
 
+  # Note the omission of jax.Array, since this is covered below.
+  @parameterized.parameters(
+      (np.arange(8),),
+      (2,),
+      (2.2,),
+      ('foo',),
+      (np.asarray(3.14),),
+  )
+  def test_standard_leaf_types_as_checkpointable(self, value):
+    with self.subTest('save_pytree'):
+      with self.assertRaises(registration.NotHandleableError):
+        ocp.save_pytree(self.directory, value)
+    with self.subTest('save_checkpointables'):
+      ocp.save_checkpointables(self.directory, {'foo': value})
+      loaded = ocp.load_checkpointables(self.directory)['foo']
+      if isinstance(value, np.ndarray):
+        np.testing.assert_array_equal(loaded, value)
+      else:
+        self.assertEqual(loaded, value)
+
+  def test_jax_array_as_checkpointable(self):
+    value = jnp.arange(
+        16,
+        device=jax.sharding.NamedSharding(
+            jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)),
+            jax.sharding.PartitionSpec(),
+        ),
+    )
+    with self.subTest('save_pytree'):
+      with self.assertRaises(registration.NotHandleableError):
+        ocp.save_pytree(self.directory, value)
+    with self.subTest('save_checkpointables'):
+      ocp.save_checkpointables(self.directory, {'foo': value})
+      loaded = ocp.load_checkpointables(self.directory)['foo']
+      test_utils.assert_tree_equal(self, value, loaded)
+
   def test_jax_array_leaf_types(self):
     mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
     # TODO(cpgaffney): Add support for missing arrays.
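The two tests above exercise the new `LeafHandler` registration announced in the changelog end to end: a bare leaf is rejected by `save_pytree` but accepted as its own checkpointable. A minimal usage sketch, mirroring the test code rather than quoting a documented API; the import alias and `path` are assumptions:

```python
# Sketch: saving an ordinary PyTree leaf as its own checkpointable, using the
# newly registered leaf handlers (ArrayHandler, ScalarHandler, StringHandler).
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp  # assumed alias, as in tests

path = '/tmp/my_checkpoint'  # placeholder directory
ocp.save_checkpointables(path, {'step': 3, 'weights': np.arange(8)})

restored = ocp.load_checkpointables(path)
assert restored['step'] == 3
np.testing.assert_array_equal(restored['weights'], np.arange(8))
```

Saving the same bare leaf via `save_pytree` raises `registration.NotHandleableError`, as both tests assert.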
@@ -220,10 +256,23 @@ def test_jax_array_leaf_types(self): loaded = ocp.load_pytree(self.directory / k, [as_abstract_type(v)]) test_utils.assert_tree_equal(self, [v], loaded) with self.subTest('without_abstract_pytree'): - if multihost.is_pathways_backend(): - self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = ocp.load_pytree(self.directory / k) - test_utils.assert_tree_equal(self, [v], loaded) + if not multihost.is_pathways_backend(): + loaded = ocp.load_pytree(self.directory / k) + test_utils.assert_tree_equal(self, [v], loaded) + + def test_save_unregistered_type_as_pytree(self): + with self.assertRaises(registration.NotHandleableError): + ocp.save_pytree(self.directory, handler_utils.Foo(1, 'hi')) + + @parameterized.parameters( + ({},), + ([],), + ('hello',), + (None,), + ) + def test_save_checkpointables_invalid(self, checkpointables): + with self.assertRaises(ValueError): + ocp.save_checkpointables(self.directory, checkpointables) def test_leaf_change_type(self): mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) diff --git a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb index 389bc637d..321a4fa5c 100644 --- a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb +++ b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb @@ -322,14 +322,14 @@ "id": "G0oCkcDZb8g_" }, "source": [ - "| Leaf Type | Abstract Leaf Type | Properties |\n", + "| `Leaf` Type | `AbstractLeaf` Type | Properties |\n", ":------- | :-------- | :-------- |\n", "|`jax.Array`|`AbstractShardedArray` (`jax.ShapeDtypeStruct`) |`shape`, `dtype`, `sharding`|\n", "|`np.ndarray`|`AbstractArray` (`np.ndarray`) |`shape`, `dtype`|\n", - "|`int`|`int`, `type[int]`| |\n", - "|`float`|`float`, `type[float]`| |\n", - "|`bytes`|`bytes`, `type[bytes]`| |\n", - "|`str`|`str`, `type[str]`| |" + "|`int`|`int`| |\n", + "|`float`|`float`| |\n", + "|`bytes`|`bytes`| |\n", + "|`str`|`str`| |" ] }, { @@ -338,7 +338,63 @@ "id": "01zrVpfcdq7m" }, "source": [ - "Note that `None` is always a valid abstract leaf; it serves as an indication that the leaf should be restored using metadata stored in the checkpoint." + "`None` is always a valid abstract leaf; it serves as an indication that the leaf should be restored using metadata stored in the checkpoint.\n", + "\n", + "`Type[AbstractLeaf]` is also always a valid abstract leaf; it again serves as an indication that the leaf should be restored using the metadata, but with the additional constraint to load as the indicated type. For example, instead of specifying `jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=...)`, it is sufficient to pass `jax.ShapeDtypeStruct`. Similarly, instead of passing `0` to restore as an `int`, the type itself may be passed." + ] + }, + { + "metadata": { + "id": "nRn_2IgV09xT" + }, + "cell_type": "markdown", + "source": [ + "To summarize, here are the ways you can load a PyTree using abstract leaves, with the way we most recommend at the top, and the way we least recommend at the bottom.\n", + "\n", + "**1. Fully-specified abstract values**\n", + "\n", + "This provides the most loading validations and requires the least amount of\n", + "unnecessary metadata reads.\n", + "\n", + "```\n", + "abstract_pytree = {\n", + " 'a': jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=jax.sharding.NamedSharding(...))\n", + "}\n", + "```\n", + "\n", + "**2. 
Only types specified**\n", + "\n", + "This guarantees that each leaf will be loaded with the indicated type, but metadata\n", + "will be used to restore specific properties for each leaf.\n", + "\n", + "```\n", + "abstract_pytree = {\n", + " 'a': jax.ShapeDtypeStruct,\n", + " 'b': int,\n", + " 'c': np.ndarray,\n", + "}\n", + "```\n", + "\n", + "**3. `None` specified (per-leaf)**\n", + "\n", + "This is essentially the same as (2), but metadata will also be used to decide\n", + "which type each leaf should be loaded as.\n", + "\n", + "```\n", + "abstract_pytree = {\n", + " 'a': None,\n", + " 'b': None,\n", + "}\n", + "```\n", + "\n", + "**4. `None` specified**\n", + "\n", + "This loads the PyTree structure without any checks, and can lead to errors later\n", + "in your code if the checkpoint does not have the structure you expect.\n", + "\n", + "```\n", + "abstract_pytree = None\n", + "```" ] }, {