@@ -833,14 +833,10 @@ def _partial_restore_with_omission(
833833 is_leaf = tree_utils .is_empty_or_leaf ,
834834 )
835835
836- try :
837- value_metadata_tree = tree_structure_utils .tree_trim (
838- serialized_item , value_metadata_tree , strict = True
839- )
840- except ValueError as e :
841- raise ValueError (
842- 'Missing keys were found in the user-provided restore item.'
843- ) from e
836+ value_metadata_tree = tree_structure_utils .tree_trim (
837+ serialized_item , value_metadata_tree , strict = False
838+ )
839+ value_metadata_tree = value_metadata_tree .unsafe_structure
844840
845841 if restore_args is not None :
846842 restore_args = tree_structure_utils .tree_trim (
@@ -850,26 +846,37 @@ def _partial_restore_with_omission(
850846 return value_metadata_tree , restore_args
851847
852848 def _partial_restore_with_placeholders (
853- self ,
854- item : PyTree ,
855- value_metadata_tree : PyTree ,
856- ):
849+ self , item : PyTree , value_metadata_tree : PyTree
850+ ) -> PyTree :
857851 """Restores leaves from `item`, except for those marked as placeholders."""
858852 serialized_item = tree_utils .serialize_tree (item , keep_empty_nodes = True )
859- diff = tree_structure_utils .tree_difference (
860- serialized_item ,
861- value_metadata_tree ,
862- is_leaf = tree_utils .is_empty_or_leaf ,
863- leaves_equal = lambda a , b : True ,
853+ diff = (
854+ tree_structure_utils .tree_difference (
855+ serialized_item ,
856+ value_metadata_tree ,
857+ is_leaf = tree_utils .is_empty_or_leaf ,
858+ leaves_equal = lambda a , b : True ,
859+ )
860+ or {}
864861 )
865- if diff is not None :
866- formatted_diff = tree_structure_utils .format_tree_diff (
867- diff , source_label = 'Item' , target_label = 'Metadata'
868- )
869- raise ValueError (
870- 'User-provided restore item and on-disk value metadata tree'
871- f' structures do not match:\n { formatted_diff } '
872- )
862+ for keypath , value_diff in tree_utils .to_flat_dict (
863+ diff , is_leaf = lambda x : isinstance (x , tree_structure_utils .Diff )
864+ ).items ():
865+ if value_diff .lhs is PLACEHOLDER and value_diff .rhs is None :
866+ parent = value_metadata_tree
867+ for key in keypath [:- 1 ]:
868+ parent = parent [key ]
869+ parent [keypath [- 1 ]] = PLACEHOLDER
870+ else :
871+ formatted_diff = tree_structure_utils .format_tree_diff (
872+ diff , source_label = 'Item' , target_label = 'Metadata'
873+ )
874+ raise ValueError (
875+ 'User-provided restore item and on-disk value metadata tree'
876+ f' structures do not match:\n { formatted_diff } \n If this mismatch is'
877+ ' intentional, pass `partial_restore=True` to only restore'
878+ ' parameters found in `item`.'
879+ )
873880 return jax .tree .map (
874881 lambda v , i : PLACEHOLDER if type_handlers .is_placeholder (i ) else v ,
875882 value_metadata_tree ,
@@ -1021,7 +1028,9 @@ class TrainState:
10211028 )
10221029 raise ValueError (
10231030 'User-provided restore item and on-disk value metadata tree'
1024- f' structures do not match:\n { formatted_diff } '
1031+ f' structures do not match:\n { formatted_diff } \n If this mismatch is'
1032+ ' intentional, pass `partial_restore=True` to only restore'
1033+ ' parameters found in `item`.'
10251034 )
10261035 restore_args = _fill_missing_save_or_restore_args (
10271036 item , restore_args , mode = 'restore'
@@ -1043,6 +1052,14 @@ class TrainState:
10431052 )
10441053 )
10451054
1055+ if args .partial_restore :
1056+ restored_item = jax .tree .map (
1057+ lambda r , i : i if r is type_handlers .PLACEHOLDER else r ,
1058+ restored_item ,
1059+ item ,
1060+ is_leaf = tree_utils .is_empty_or_leaf ,
1061+ )
1062+
10461063 if logging .vlog_is_on (1 ):
10471064 logging .vlog (1 , 'param_infos: %s' , param_infos )
10481065 logging .vlog (1 , 'checkpoint_restore_args: %s' , restore_args )
0 commit comments