Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a check to prevent zero-sized arrays from being saved. This behav…
Browse files Browse the repository at this point in the history
…ior already resulted in an error, but it was one that was difficult to parse.

PiperOrigin-RevId: 721870235
cpgaffney1 authored and Orbax Authors committed Jan 31, 2025
1 parent 6e80ecc commit b360537
Showing 3 changed files with 24 additions and 0 deletions.
5 changes: 5 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -13,6 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`AsyncCheckpointer.save()`, and `CheckpointManager.save()`, which saves a custom
dict of user metadata to `StepMetadata`.

### Fixed

- Add a check that rejects zero-sized arrays at save time. Saving such arrays
already resulted in an error, but one whose message was difficult to interpret.

## [0.11.1] - 2025-01-28

### Changed
Original file line number Diff line number Diff line change
@@ -2308,3 +2308,14 @@ def deserialize(
checkpoint_handler.save(
self.directory, args=PyTreeSaveArgs(self.pytree)
)

@parameterized.parameters((True,), (False,))
def test_zero_size_array(self, use_jax_array: bool):
  """Checks that saving a zero-sized array raises a `ValueError`.

  Args:
    use_jax_array: If True, the empty numpy array is first passed through
      `test_utils.create_sharded_array`; otherwise it is saved as-is.
  """
  empty = np.ones(shape=(0,))
  # The mesh and partition spec are constructed for both parameterizations,
  # even though only the jax-array case consumes them.
  device_mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('x',))
  spec = jax.sharding.PartitionSpec()
  leaf = (
      test_utils.create_sharded_array(empty, device_mesh, spec)
      if use_jax_array
      else empty
  )
  pytree = [leaf]
  with self.assertRaisesRegex(ValueError, 'zero size'):
    self.handler.save(self.directory, args=PyTreeSaveArgs(pytree))
Original file line number Diff line number Diff line change
@@ -249,6 +249,12 @@ def check_input_arguments(*args):
raise ValueError('Found input args with mismatched lengths.')


def check_array_values(values: Sequence[Union[jax.Array, np.ndarray]]):
  """Validates that none of the given arrays is zero-sized.

  Args:
    values: The arrays to validate.

  Raises:
    ValueError: If any array has `size == 0`. The message reports the position
      (within `values`) and the shape of the first offending array so that the
      problematic entry can be located when many arrays are saved at once.
  """
  for index, value in enumerate(values):
    if value.size == 0:
      # Keep the original sentence as a prefix so existing matches on the
      # message (e.g. 'zero size') continue to work.
      raise ValueError(
          'Cannot save arrays with zero size. Found array at index'
          f' {index} with shape {value.shape}.'
      )


async def _validate_params(
directory: epath.Path,
ts_context: ts.Context,
@@ -631,6 +637,7 @@ async def serialize(
"""Uses Tensorstore to serialize a numpy array."""
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)
check_array_values(values)
if logging.vlog_is_on(1):
_print_ts_debug_data(self._metadata_key, infos)
copied_values = [copy.deepcopy(v) for v in values]
@@ -1102,6 +1109,7 @@ async def serialize(
)
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)
check_array_values(values)

assert all([info.enable_pinned_host_transfer for info in infos]) or all(
[not info.enable_pinned_host_transfer for info in infos]

0 comments on commit b360537

Please sign in to comment.