Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/async_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ async def is_dir(path: epath.Path):
return await asyncio.to_thread(path.is_dir)


async def is_file(path: epath.Path):
return await asyncio.to_thread(path.is_file)


async def is_link(path: epath.Path):
return await asyncio.to_thread(os.path.islink, path)

Expand All @@ -101,3 +105,7 @@ async def iterdir(path: epath.Path):

async def glob(path: epath.Path, pattern: str) -> Iterator[epath.Path]:
return await asyncio.to_thread(path.glob, pattern)


async def unlink(path: epath.Path, missing_ok: bool = False):
return await asyncio.to_thread(path.unlink, missing_ok=missing_ok)
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from orbax.checkpoint.experimental.v1._src.context.context import (
Context,
)
from orbax.checkpoint.experimental.v1._src.layout.format_utils import (
from orbax.checkpoint.experimental.v1._src.layout.orbax_layout import (
is_orbax_checkpoint,
)
from orbax.checkpoint.experimental.v1._src.loading.loading import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import itertools
from typing import Any, Awaitable

from absl import logging
Expand All @@ -29,7 +30,7 @@
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.path import format_utils
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.path import types as path_types


Expand All @@ -39,6 +40,15 @@
ORBAX_CHECKPOINT_INDICATOR_FILE = 'orbax.checkpoint'


def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]:
return list(
itertools.islice(
(subdir.name for subdir in directory.iterdir() if subdir.is_dir()),
limit,
)
)


_V0_ERROR_MESSAGE = (
'If your checkpoint was saved with the Orbax V0 API, please follow the'
' instructions at'
Expand All @@ -57,7 +67,9 @@ async def _create_orbax_identifier_file(
"""Creates a file called `orbax.checkpoint` for easy identification."""
directory = await directory.await_creation()
if multihost.is_primary_host(primary_host):
await async_path.touch(directory / 'orbax.checkpoint', exist_ok=True)
await async_path.touch(
directory / ORBAX_CHECKPOINT_INDICATOR_FILE, exist_ok=True
)


class CompositeHandler:
Expand Down Expand Up @@ -165,7 +177,7 @@ async def load(
abstract_checkpointables = {
name: None
for name in handlers_for_load.keys()
if name not in format_utils.RESERVED_CHECKPOINTABLE_KEYS
if name not in checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS
and name in existing_checkpointable_names
}
if any(
Expand Down Expand Up @@ -265,8 +277,8 @@ def _get_saved_handler_typestrs(
return saved_metadata.item_handlers # found step level metadata.
raise ValueError(
f'Path at {directory} contains subdirectories:'
f' {format_utils.subdirs(directory)}, which are expected to match the'
' keys given by the _CHECKPOINT_METADATA file:'
f' {_subdirs(directory)}, which are expected to'
' match the keys given by the _CHECKPOINT_METADATA file:'
f' {saved_metadata.item_handlers}. If you intended to load a pytree'
' checkpoint from the given path, then please consider using'
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
Expand All @@ -293,8 +305,8 @@ def _get_saved_handler_typestrs(
if isinstance(saved_metadata.item_handlers, dict):
raise ValueError(
f'Path at {directory} contains subdirectories:'
f' {format_utils.subdirs(directory)}, which are expected to match'
' the keys given by the _CHECKPOINT_METADATA file:'
f' {_subdirs(directory)}, which are expected to'
' match the keys given by the _CHECKPOINT_METADATA file:'
f' {saved_metadata.item_handlers}. If you intended to load a pytree'
' checkpoint from the given path, then please consider using'
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
Path = types.Path


### Constants shared by all layouts. ###

PYTREE_CHECKPOINTABLE_KEY = "pytree"

METRICS_CHECKPOINTABLE_KEY = "metrics"

RESERVED_CHECKPOINTABLE_KEYS = frozenset({
METRICS_CHECKPOINTABLE_KEY,
})


class InvalidLayoutError(ValueError):
"""Raised when the checkpoint layout is invalid."""

Expand Down Expand Up @@ -51,7 +62,7 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
"""
...

def validate(self) -> None:
async def validate(self) -> None:
"""Validates the path, determining if it conforms to this instance.

Returns:
Expand All @@ -62,7 +73,7 @@ def validate(self) -> None:
"""
...

def validate_pytree(self, checkpointable_name: str | None) -> None:
async def validate_pytree(self, checkpointable_name: str | None) -> None:
"""Validates the path as a PyTree checkpoint.

Args:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@

from typing import Any, Awaitable

from etils import epath
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import composite_handler
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import format_utils
from orbax.checkpoint.experimental.v1._src.path import types
from orbax.checkpoint.experimental.v1._src.path import types as path_types

InvalidLayoutError = checkpoint_layout.InvalidLayoutError
CompositeHandler = composite_handler.CompositeHandler
Path = types.Path
Path = path_types.Path
CheckpointLayout = checkpoint_layout.CheckpointLayout

ORBAX_CHECKPOINT_INDICATOR_FILE = (
composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE
)
PYTREE_METADATA_FILE = "_METADATA"
ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint"


_V0_ERROR_MESSAGE = (
Expand All @@ -48,6 +48,23 @@
)


def is_orbax_checkpoint(path: path_types.PathLike) -> bool:
"""Determines if the given path is an Orbax checkpoint.

Args:
path: The path to the checkpoint directory.

Returns:
True if the path is an Orbax checkpoint, False otherwise.
"""
path = epath.Path(path)
try:
OrbaxLayout(path).validate()
return True
except InvalidLayoutError:
return False


class OrbaxLayout(CheckpointLayout):
"""OrbaxLayout.

Expand Down Expand Up @@ -96,7 +113,7 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
custom_metadata=step_metadata.custom_metadata,
)

def validate(self):
async def validate(self):
try:
format_utils.validate_checkpoint_directory(self._path)
if self.has_indicator_file:
Expand All @@ -107,7 +124,7 @@ def validate(self):
f" {_GENERAL_ERROR_MESSAGE}"
) from e

def validate_pytree(self, checkpointable_name: str | None) -> None:
async def validate_pytree(self, checkpointable_name: str | None) -> None:
"""Validates the given path as a PyTree checkpoint."""
try:
format_utils.validate_pytree_checkpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ def setUp(self):
custom_metadata=self.custom_metadata,
)

def test_valid_orbax_checkpoint(self):
async def test_valid_orbax_checkpoint(self):
layout = OrbaxLayout(self.orbax_path / '0')
layout.validate()
await layout.validate()

def test_invalid_orbax_checkpoint(self):
async def test_invalid_orbax_checkpoint(self):
layout = OrbaxLayout(self.safetensors_path)
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

def test_validate_fails_not_directory(self):
async def test_validate_fails_not_directory(self):
layout = OrbaxLayout(self.orbax_path / '1')
# This tests `format_utils.validate_checkpoint_directory`
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

def test_skips_metadata_validation_if_no_indicator_file(self):
async def test_skips_metadata_validation_if_no_indicator_file(self):
layout = OrbaxLayout(self.orbax_path / '0')
indicator_path = (
self.orbax_path
Expand All @@ -81,25 +81,25 @@ def test_skips_metadata_validation_if_no_indicator_file(self):
metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA'
self.assertTrue(metadata_path.exists())
metadata_path.rmtree() # Remove the metadata file
layout.validate()
await layout.validate()

def test_validate_fails_no_metadata_file(self):
async def test_validate_fails_no_metadata_file(self):
layout = OrbaxLayout(self.orbax_path / '0')
# This tests `format_utils.validate_checkpoint_metadata`
metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA'
self.assertTrue(metadata_path.exists())
metadata_path.rmtree() # Remove the metadata file
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

def test_validate_fails_tmp_directory(self):
async def test_validate_fails_tmp_directory(self):
# This simulates a temporary directory created by Orbax (should fail)
test_utils.save_fake_tmp_dir(self.orbax_path, 0, 'test_checkpoint.tmp')
layout = OrbaxLayout(
epath.Path(self.test_dir.full_path) / 'test_checkpoint.tmp'
)
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

async def test_load_orbax_checkpoint(self):
layout = OrbaxLayout(self.orbax_path / '0')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
InvalidLayoutError = checkpoint_layout.InvalidLayoutError


def get_checkpoint_layout(
async def get_checkpoint_layout(
path: path_types.PathLike, layout_enum: options_lib.CheckpointLayout
) -> checkpoint_layout.CheckpointLayout:
"""Returns the checkpoint layout class for the given path.
Expand Down Expand Up @@ -52,7 +52,7 @@ def get_checkpoint_layout(

try:
layout = layout_class(path)
layout.validate()
await layout.validate()
return layout
except InvalidLayoutError as e:
raise InvalidLayoutError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import numpy as np
from orbax.checkpoint._src.arrays import numpy_utils
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import format_utils
Expand Down Expand Up @@ -280,8 +281,11 @@ async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
custom_metadata=custom_metadata,
)

def validate(self):
if self._path.is_file() and self._path.suffix == ".safetensors":
async def validate(self):
if (
await async_path.is_file(self._path)
and self._path.suffix == ".safetensors"
):
return
else:
raise InvalidLayoutError(
Expand All @@ -290,7 +294,7 @@ def validate(self):
" suffix."
)

def validate_pytree(self, checkpointable_name: str | None) -> None:
async def validate_pytree(self, checkpointable_name: str | None) -> None:
return

async def load(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,27 @@ def setUp(self):
)
saving.save_pytree(self.orbax_path, self.object_to_save)

def test_valid_safetensors_checkpoint(self):
async def test_valid_safetensors_checkpoint(self):
layout = SafetensorsLayout(self.safetensors_path)
layout.validate()
await layout.validate()

def test_invalid_safetensors_checkpoint_orbax(self):
async def test_invalid_safetensors_checkpoint_orbax(self):
layout = SafetensorsLayout(self.orbax_path / '0')
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

def test_validate_fails_not_file(self):
async def test_validate_fails_not_file(self):
layout = SafetensorsLayout(epath.Path(self.test_dir.full_path))
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

def test_validate_fails_wrong_suffix(self):
async def test_validate_fails_wrong_suffix(self):
wrong_suffix_path = (
epath.Path(self.test_dir.full_path) / 'test_checkpoint.txt'
)
layout = SafetensorsLayout(wrong_suffix_path)
with self.assertRaises(InvalidLayoutError):
layout.validate()
await layout.validate()

@parameterized.product(
dtype=[
Expand Down
Loading
Loading