Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- #v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree
leaves can also be saved as individual checkpointables.
- #v1 Modify `LeafHandler` definitions so that both `AbstractLeaf` and
`Type[AbstractLeaf]` are always accepted as valid abstract values.

## [0.11.25] - 2025-09-10

### Changed
Expand Down
29 changes: 17 additions & 12 deletions checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,23 @@ class MultiProcessTest(absltest.TestCase):
def setUp(self):
  """Validates the process count and synchronizes processes per test case.

  On the Pathways backend exactly one process is expected and no barrier is
  used. On any other backend the process count must equal
  `NUM_PROCESSES.value`, and every process then waits at a barrier named
  after the current test method.
  """
  super().setUp()
  if multihost.is_pathways_backend():
    assert (
        jax.process_count() == 1
    ), "Expected 1 process for Pathways backend."
  else:
    assert jax.process_count() == NUM_PROCESSES.value, (
        jax.process_count(),
        NUM_PROCESSES.value,
    )
    # Make sure all processes are at the same test case.
    client = multihost.get_jax_distributed_client()
    # Note that the name of this barrier is long and complicated, to prevent
    # any collisions with barriers in user test code.
    # NOTE(review): 10000 is presumably the barrier timeout in milliseconds;
    # confirm against the distributed client API.
    client.wait_at_barrier(
        f"multiprocess_test_ensure_all_processes_arrive_at_test_case_{self._testMethodName}",
        10000,
    )

def multiprocess_create_tempdir(
self, name: str | None = None
Expand Down
Original file line number Diff line number Diff line change
def get_handlers_for_save(
    self, checkpointables: dict[str, Any]
) -> dict[str, handler_types.CheckpointableHandler]:
  """Returns a mapping from checkpointable name to handler.

  Args:
    checkpointables: Mapping from checkpointable name to the object that will
      be saved under that name.

  Returns:
    A dict with the same keys as `checkpointables`, each mapped to the handler
    resolved from `self._handler_registry` for that entry.
  """
  result = {
      checkpointable_name: registration.resolve_handler_for_save(
          self._handler_registry, checkpointable, name=checkpointable_name
      )
      for checkpointable_name, checkpointable in checkpointables.items()
  }
  # Log only after every handler has been resolved, so a resolution failure
  # does not leave a partial list of log lines behind.
  for name, handler in result.items():
    logging.info('Resolved handler type %s for %s', type(handler), name)
  return result

def get_handlers_for_load(
self, directory: path_types.Path, abstract_checkpointables: dict[str, Any]
Expand Down Expand Up @@ -247,6 +250,8 @@ def get_handlers_for_load(
handler_typestr=handler_typestr,
)
loadable_checkpointable_names_to_handlers[name] = handler
for name, handler in loadable_checkpointable_names_to_handlers.items():
logging.info('Resolved handler type %s for %s', type(handler), name)
return loadable_checkpointable_names_to_handlers

def _get_saved_handler_typestrs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Type

from orbax.checkpoint.experimental.v1._src.handlers import json_handler
from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import registration
Expand Down Expand Up @@ -54,3 +55,9 @@ def _try_register_handler(
_try_register_handler(
pytree_handler.PyTreeHandler, format_utils.PYTREE_CHECKPOINTABLE_KEY
)

# Registration for leaf types that can be treated as distinct checkpointables.
# Unlike the PyTree registration above, no checkpointable name is passed for
# these handlers.
_try_register_handler(leaf_handler.ShardedArrayHandler)
_try_register_handler(leaf_handler.ArrayHandler)
_try_register_handler(leaf_handler.ScalarHandler)
_try_register_handler(leaf_handler.StringHandler)
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for :py:class:`serialization.LeafHandler`.

This :py:class:`CheckpointableHandler` is a wrapper for checkpointables where
support is already implemented at the PyTree leaf level.
"""

from typing import Any, Awaitable, TypeVar

import jax
import numpy as np
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types


Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')


class _LeafHandler(handler_types.CheckpointableHandler[Leaf, AbstractLeaf]):
  """Adapter that exposes a single PyTree leaf as its own checkpointable.

  Every operation is delegated to a freshly constructed `PyTreeHandler`. The
  leaf is wrapped in a one-element list on the way in, and element 0 is
  extracted from the result on the way out.
  """

  def __init__(self):
    # Capture the context that is current at construction time.
    self._context = context_lib.get_context()

  async def save(
      self, directory: path_types.PathAwaitingCreation, checkpointable: Leaf
  ) -> Awaitable[None]:
    """Saves `checkpointable` as a one-element PyTree under `directory`.

    Args:
      directory: Destination passed through to `PyTreeHandler.save`.
      checkpointable: The leaf value to save.

    Returns:
      The result of awaiting `PyTreeHandler.save` on the wrapped leaf.
    """
    tree_handler = pytree_handler.PyTreeHandler()
    single_leaf_tree = [checkpointable]
    return await tree_handler.save(directory, single_leaf_tree)

  async def load(
      self,
      directory: path_types.Path,
      abstract_checkpointable: AbstractLeaf | None = None,
  ) -> Awaitable[Leaf]:
    """Loads the leaf stored under `directory`.

    Args:
      directory: Source passed through to `PyTreeHandler.load`.
      abstract_checkpointable: Optional abstract value for the leaf. When
        `None`, `None` is forwarded as the abstract PyTree; otherwise it is
        wrapped in a one-element list.

    Returns:
      An awaitable that resolves to element 0 of the loaded PyTree.
    """
    abstract_tree = (
        None if abstract_checkpointable is None else [abstract_checkpointable]
    )
    tree_handler = pytree_handler.PyTreeHandler()
    pending_tree = await tree_handler.load(directory, abstract_tree)

    async def _unwrap_first_leaf() -> Leaf:
      # Finish the delegated load, then strip the one-element wrapper.
      return (await pending_tree)[0]

    return _unwrap_first_leaf()

  async def metadata(self, directory: path_types.Path) -> AbstractLeaf:
    """Returns element 0 of the PyTree metadata read from `directory`."""
    tree_handler = pytree_handler.PyTreeHandler()
    return (await tree_handler.metadata(directory))[0]

  def is_handleable(self, checkpointable: Any) -> bool:
    """Reports whether `checkpointable` is accepted as a PyTree leaf.

    Returns:
      False if `PyTreeHandler.validate_leaves_handleable` raises
      `UnregisteredTypeError` for the wrapped value, True otherwise.
    """
    tree_handler = pytree_handler.PyTreeHandler()
    try:
      tree_handler.validate_leaves_handleable([checkpointable])
    except registry.UnregisteredTypeError:
      return False
    else:
      return True

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
    """Reports whether `abstract_checkpointable` is an accepted abstract leaf.

    Returns:
      False if `PyTreeHandler.validate_abstract_leaves_handleable` raises
      `UnregisteredTypeError` for the wrapped value, True otherwise.
    """
    tree_handler = pytree_handler.PyTreeHandler()
    try:
      tree_handler.validate_abstract_leaves_handleable(
          [abstract_checkpointable]
      )
    except registry.UnregisteredTypeError:
      return False
    else:
      return True


class ShardedArrayHandler(
    _LeafHandler[jax.Array, serialization_types.AbstractShardedArray]
):
  """`_LeafHandler` typed for `jax.Array` / `AbstractShardedArray`."""


class ArrayHandler(_LeafHandler[np.ndarray, serialization_types.AbstractArray]):
  """`_LeafHandler` typed for `np.ndarray` / `AbstractArray`."""


class StringHandler(_LeafHandler[str, serialization_types.AbstractString]):
  """`_LeafHandler` typed for `str` / `AbstractString`."""


class ScalarHandler(
    _LeafHandler[serialization_types.Scalar, serialization_types.AbstractScalar]
):
  """`_LeafHandler` typed for `Scalar` / `AbstractScalar`."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Awaitable

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.arrays import abstract_arrays
from orbax.checkpoint.experimental.v1._src.handlers import leaf_handler
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.testing import handler_utils as handler_test_utils
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


PathAwaitingCreation = path_types.PathAwaitingCreation
PathLike = path_types.PathLike
Path = path_types.Path
Json = tree_types.JsonType
create_test_handler = handler_test_utils.create_test_handler

Leaf = leaf_handler.Leaf
AbstractLeaf = leaf_handler.AbstractLeaf


async def _run_awaitable(awaitable: Awaitable[Any]) -> Any:
  """Awaits `awaitable` and returns its result."""
  outcome = await awaitable
  return outcome


class FooHandler(
    leaf_handler._LeafHandler[
        handler_test_utils.Foo, handler_test_utils.AbstractFoo
    ]
):
  """Test-only leaf handler that claims `Foo` / `AbstractFoo` values.

  Only the handleability checks are overridden; every other method is
  inherited from `_LeafHandler`.
  """

  def is_handleable(self, checkpointable: Any) -> bool:
    """Returns whether `checkpointable` is a `Foo` instance."""
    foo_cls = handler_test_utils.Foo
    return isinstance(checkpointable, foo_cls)

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
    """Returns whether `abstract_checkpointable` is an `AbstractFoo`."""
    abstract_foo_cls = handler_test_utils.AbstractFoo
    return isinstance(abstract_checkpointable, abstract_foo_cls)


class LeafHandlerTest(parameterized.TestCase):
  """Save/load and handleability tests for the leaf handlers."""

  def setUp(self):
    """Creates the checkpoint directory and the per-handler test cases."""
    super().setUp()
    tempdir = self.create_tempdir(name='checkpointing_test')
    self.directory = epath.Path(tempdir.full_path)

    # Each case is a (value, abstract_value) pair. The abstract value is
    # either an instance or a bare type.
    sharded_array_cases = [
        (
            jnp.arange(8),
            abstract_arrays.to_shape_dtype_struct(jnp.arange(8)),
        ),
        (jnp.arange(8), jax.ShapeDtypeStruct),
    ]
    array_cases = [(np.arange(8), np.empty_like(np.arange(8)))]
    scalar_cases = [
        (123, int),
        (123, 0),
        (123.456, float),
        (123.456, 0.0),
    ]
    string_cases = [('test', str), ('test', '_')]

    self.values_and_abstract_values = {
        leaf_handler.ShardedArrayHandler: sharded_array_cases,
        leaf_handler.ArrayHandler: array_cases,
        leaf_handler.ScalarHandler: scalar_cases,
        leaf_handler.StringHandler: string_cases,
    }

  def validate_load(
      self,
      handler: handler_test_utils.TestHandler[Leaf, AbstractLeaf],
      value: Leaf,
      abstract_value: AbstractLeaf,
      directory: Path | None = None,
  ):
    """Checks that loading reproduces `value`, with and without an abstract.

    Args:
      handler: Test handler used to perform the loads.
      value: The value the load is expected to return.
      abstract_value: Abstract value passed to the first load.
      directory: Directory to load from; falls back to `self.directory`.
    """
    if not directory:
      directory = self.directory
    load_variants = (
        ('load_with_abstract', (directory, abstract_value)),
        ('load_without_abstract', (directory,)),
    )
    for subtest_name, load_args in load_variants:
      with self.subTest(subtest_name):
        restored = handler.load(*load_args)
        test_utils.assert_array_equal(self, value, restored)

  @parameterized.parameters(
      leaf_handler.ShardedArrayHandler,
      leaf_handler.ArrayHandler,
      leaf_handler.ScalarHandler,
      leaf_handler.StringHandler,
  )
  def test_save_load(self, handler_cls):
    """Saves and reloads every case registered for `handler_cls`."""
    handler = create_test_handler(handler_cls)
    test_cases = self.values_and_abstract_values[handler_cls]

    # `Foo` and `AbstractFoo` must be rejected by every built-in leaf handler.
    unsupported_value = handler_test_utils.Foo(1, 'hi')
    self.assertFalse(handler.is_handleable(unsupported_value))
    unsupported_abstract = handler_test_utils.AbstractFoo()
    self.assertFalse(handler.is_abstract_handleable(unsupported_abstract))

    for index, case in enumerate(test_cases):
      value, abstract_value = case
      # Each case gets its own subdirectory, named by its index.
      checkpoint_dir = self.directory / str(index)
      with self.subTest(f'value={value}, abstract_value={abstract_value}'):
        logging.info(
            'Subtest: value=%s, abstract_value=%s', value, abstract_value
        )
        self.assertTrue(handler.is_handleable(value))
        self.assertTrue(handler.is_abstract_handleable(abstract_value))
        handler.save(checkpoint_dir, value)
        self.validate_load(
            handler, value, abstract_value, directory=checkpoint_dir
        )

  def test_unregistered_type(self):
    """Expects `UnregisteredTypeError` when saving or loading `Foo` values."""
    handler = create_test_handler(FooHandler)
    expected_error = registry.UnregisteredTypeError

    with self.assertRaises(expected_error):
      handler.save(self.directory, handler_test_utils.Foo(1, 'hi'))

    with self.assertRaises(expected_error):
      handler.load(self.directory, handler_test_utils.AbstractFoo())

# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading
Loading