Skip to content

Commit 1f103e5

Browse files
BlaziusMaximus authored and Orbax Authors committed
Allow unexpected keys in restore_item during partial restoration.
PiperOrigin-RevId: 823229099
1 parent 5cad306 commit 1f103e5

File tree

3 files changed

+43
-74
lines changed

3 files changed

+43
-74
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py

Lines changed: 0 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -767,30 +767,6 @@ def test_partial_restore_with_omission(self):
767767
)
768768
test_utils.assert_tree_equal(self, expected, restored)
769769

770-
with self.subTest('extra_leaf'):
771-
with self.checkpointer(
772-
PyTreeCheckpointHandler()
773-
) as restore_checkpointer:
774-
reference_item = {
775-
'a': 0,
776-
'c': {
777-
'a': 0,
778-
},
779-
'z': 0,
780-
}
781-
with self.assertRaisesRegex(
782-
ValueError,
783-
'Missing keys were found in the user-provided restore item.',
784-
):
785-
restore_checkpointer.restore(
786-
directory,
787-
args=pytree_checkpoint_handler.PyTreeRestoreArgs(
788-
item=reference_item,
789-
restore_args=self.pytree_restore_args,
790-
partial_restore=True,
791-
),
792-
)
793-
794770
def test_restore_logs_read_event(self):
795771
"""Tests that restore logs a read event to DM Sawmill log."""
796772
with self.checkpointer(PyTreeCheckpointHandler()) as checkpointer:

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 43 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -833,14 +833,10 @@ def _partial_restore_with_omission(
833833
is_leaf=tree_utils.is_empty_or_leaf,
834834
)
835835

836-
try:
837-
value_metadata_tree = tree_structure_utils.tree_trim(
838-
serialized_item, value_metadata_tree, strict=True
839-
)
840-
except ValueError as e:
841-
raise ValueError(
842-
'Missing keys were found in the user-provided restore item.'
843-
) from e
836+
value_metadata_tree = tree_structure_utils.tree_trim(
837+
serialized_item, value_metadata_tree, strict=False
838+
)
839+
value_metadata_tree = value_metadata_tree.unsafe_structure
844840

845841
if restore_args is not None:
846842
restore_args = tree_structure_utils.tree_trim(
@@ -850,26 +846,37 @@ def _partial_restore_with_omission(
850846
return value_metadata_tree, restore_args
851847

852848
def _partial_restore_with_placeholders(
853-
self,
854-
item: PyTree,
855-
value_metadata_tree: PyTree,
856-
):
849+
self, item: PyTree, value_metadata_tree: PyTree
850+
) -> PyTree:
857851
"""Restores leaves from `item`, except for those marked as placeholders."""
858852
serialized_item = tree_utils.serialize_tree(item, keep_empty_nodes=True)
859-
diff = tree_structure_utils.tree_difference(
860-
serialized_item,
861-
value_metadata_tree,
862-
is_leaf=tree_utils.is_empty_or_leaf,
863-
leaves_equal=lambda a, b: True,
853+
diff = (
854+
tree_structure_utils.tree_difference(
855+
serialized_item,
856+
value_metadata_tree,
857+
is_leaf=tree_utils.is_empty_or_leaf,
858+
leaves_equal=lambda a, b: True,
859+
)
860+
or {}
864861
)
865-
if diff is not None:
866-
formatted_diff = tree_structure_utils.format_tree_diff(
867-
diff, source_label='Item', target_label='Metadata'
868-
)
869-
raise ValueError(
870-
'User-provided restore item and on-disk value metadata tree'
871-
f' structures do not match:\n{formatted_diff}'
872-
)
862+
for keypath, value_diff in tree_utils.to_flat_dict(
863+
diff, is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff)
864+
).items():
865+
if value_diff.lhs is PLACEHOLDER and value_diff.rhs is None:
866+
parent = value_metadata_tree
867+
for key in keypath[:-1]:
868+
parent = parent[key]
869+
parent[keypath[-1]] = PLACEHOLDER
870+
else:
871+
formatted_diff = tree_structure_utils.format_tree_diff(
872+
diff, source_label='Item', target_label='Metadata'
873+
)
874+
raise ValueError(
875+
'User-provided restore item and on-disk value metadata tree'
876+
f' structures do not match:\n{formatted_diff}\nIf this mismatch is'
877+
' intentional, pass `partial_restore=True` to only restore'
878+
' parameters found in `item`.'
879+
)
873880
return jax.tree.map(
874881
lambda v, i: PLACEHOLDER if type_handlers.is_placeholder(i) else v,
875882
value_metadata_tree,
@@ -1021,7 +1028,9 @@ class TrainState:
10211028
)
10221029
raise ValueError(
10231030
'User-provided restore item and on-disk value metadata tree'
1024-
f' structures do not match:\n{formatted_diff}'
1031+
f' structures do not match:\n{formatted_diff}\nIf this mismatch is'
1032+
' intentional, pass `partial_restore=True` to only restore'
1033+
' parameters found in `item`.'
10251034
)
10261035
restore_args = _fill_missing_save_or_restore_args(
10271036
item, restore_args, mode='restore'
@@ -1043,6 +1052,14 @@ class TrainState:
10431052
)
10441053
)
10451054

1055+
if args.partial_restore:
1056+
restored_item = jax.tree.map(
1057+
lambda r, i: i if r is type_handlers.PLACEHOLDER else r,
1058+
restored_item,
1059+
item,
1060+
is_leaf=tree_utils.is_empty_or_leaf,
1061+
)
1062+
10461063
if logging.vlog_is_on(1):
10471064
logging.vlog(1, 'param_infos: %s', param_infos)
10481065
logging.vlog(1, 'checkpoint_restore_args: %s', restore_args)

checkpoint/orbax/checkpoint/checkpoint_manager_test.py

Lines changed: 0 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -3907,30 +3907,6 @@ def test_partial_restore_with_omission(self):
39073907
)
39083908
test_utils.assert_tree_equal(self, expected, restored)
39093909

3910-
with self.subTest('extra_leaf'):
3911-
with CheckpointManager(directory) as restore_manager:
3912-
reference_item = {
3913-
'a': 0,
3914-
# Omit 'b'
3915-
'c': {
3916-
'a': 0,
3917-
# Omit 'e'
3918-
},
3919-
'z': 0,
3920-
}
3921-
with self.assertRaisesRegex(
3922-
ValueError,
3923-
'Missing keys were found in the user-provided restore item.',
3924-
):
3925-
restore_manager.restore(
3926-
0,
3927-
args=args.PyTreeRestore(
3928-
reference_item,
3929-
restore_args=self.pytree_restore_args,
3930-
partial_restore=True,
3931-
),
3932-
)
3933-
39343910

39353911
if __name__ == '__main__':
39363912
multiprocess_test.main()

0 commit comments

Comments
 (0)