@@ -24,9 +24,9 @@ def unflatten_tensor_state_dict(
2424
2525 For example, given a previously flattened tensors_data_dict and metadata:
2626 tensors_data_dict = {
27- '0.weight:qdata ': torch.Tensor(...),
28- '0.weight:scale ': torch.Tensor(...),
29- '0.bias:_data ': torch.Tensor(...),
27+ '0._weight_qdata ': torch.Tensor(...),
28+ '0._weight_scale ': torch.Tensor(...),
29+ '0.bias': torch.Tensor(...),
3030 }
3131 metadata = {
3232 '0.weight': {
@@ -53,7 +53,7 @@ def unflatten_tensor_state_dict(
5353 }
5454
5555 Args:
56- tensors_data_dict: a dictionary from "tensor_name: tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
56+ tensors_data_dict: a dictionary from "{ tensor_name}_{ tensor_data_attribute_name} " to flattened torch.Tensor data for tensor subclass instance
5757 metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
5858
5959 Returns:
@@ -68,23 +68,28 @@ def unflatten_tensor_state_dict(
6868 result = {}
6969
7070 for tensor_name in tensor_names :
71+ prefix = f"{ tensor_name .rsplit ('.' , 1 )[0 ]} ._{ tensor_name .rsplit ('.' , 1 )[1 ]} _"
7172 tensor_tensors = {}
7273 for key , value in combined_data .items ():
73- if key .startswith (f" { tensor_name } :" ):
74+ if key .startswith (prefix ):
7475 # Remove the prefix
75- tensor_tensors [key [len (tensor_name ) + 1 :]] = value
76+ tensor_tensors [key [len (prefix ) :]] = value
7677
7778 tensor_metadata = json .loads (metadata .get (tensor_name ))
7879 tensor_type = tensor_metadata .get ("_type" )
7980
8081 if tensor_type in ALLOWED_TENSORS_SUBCLASSES :
82+ if not tensor_tensors :
83+ # we allow the option of loading in state_dict info for a single tensor
84+ # if tensor state dict info is not loaded in yet, we wait for it to be provided
85+ # in a future call
86+ continue
8187 tensor_metadata ["_data" ].update (tensor_tensors )
8288 result [tensor_name ] = object_from_dict (tensor_metadata )
8389 elif tensor_type == torch .Tensor .__name__ :
84- result [tensor_name ] = tensor_tensors [ "_data" ]
90+ result [tensor_name ] = tensors_data_dict [ tensor_name ]
8591 else :
8692 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
87-
8893 return result
8994
9095
@@ -108,9 +113,9 @@ def flatten_tensor_state_dict(
108113
109114 We flatten this to:
110115 tensors_data = {
111- '0.weight:qdata ': torch.Tensor(...),
112- '0.weight:scale ': torch.Tensor(...),
113- '0.bias:_data ': torch.Tensor(...),
116+ '0._weight_qdata ': torch.Tensor(...),
117+ '0._weight_scale ': torch.Tensor(...),
118+ '0.bias': torch.Tensor(...),
114119 }
115120 metadata = {
116121 '0.weight': {
@@ -152,22 +157,23 @@ def flatten_tensor_state_dict(
152157 tensor_dict [tensor_data_name ] = getattr (tensor , tensor_data_name )
153158
154159 tensor_metadata = json .dumps (tensor , cls = TensorSubclassAttributeJSONEncoder )
160+
161+ # Clone tensors to avoid memory sharing issues
162+ tensors_dict_to_save = {
163+ f"{ tensor_name .rsplit ('.' , 1 )[0 ]} ._{ tensor_name .rsplit ('.' , 1 )[1 ]} _{ key } " : (
164+ value .detach ().clone () if isinstance (value , torch .Tensor ) else value
165+ )
166+ for key , value in tensor_dict .items ()
167+ }
168+
155169 elif type (tensor ) is torch .Tensor :
156- tensor_dict = {"_data" : tensor }
157170 tensor_metadata = json .dumps ({"_type" : torch .Tensor .__name__ })
171+ tensors_dict_to_save = {tensor_name : tensor }
158172 else :
159173 raise ValueError (f"Unsupported tensor type: { type (tensor )} " )
160174
161- # Clone tensors to avoid memory sharing issues
162- prefixed_tensors_dict = {
163- f"{ tensor_name } :{ key } " : (
164- value .detach ().clone () if isinstance (value , torch .Tensor ) else value
165- )
166- for key , value in tensor_dict .items ()
167- }
168-
169175 metadata [tensor_name ] = tensor_metadata
170- tensors_data_dict .update (prefixed_tensors_dict )
176+ tensors_data_dict .update (tensors_dict_to_save )
171177
172178 metadata ["tensor_names" ] = json .dumps (list (tensors_dict .keys ()))
173179 return tensors_data_dict , metadata
0 commit comments