Skip to content

Commit 1539983

Browse files
committed
updating flatten/unflatten functions
1 parent 2a89491 commit 1539983

File tree

2 files changed

+44
-21
lines changed

2 files changed

+44
-21
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
import tempfile
34
import unittest
45

@@ -38,6 +39,19 @@ def load_data(file_path: str, device: str):
3839
return loaded_tensors, metadata
3940

4041

42+
def check_saved_tensor_names_format(state_dict, metadata):
43+
original_tensor_names = metadata["tensor_names"]
44+
for key in state_dict.keys():
45+
m = re.match(r"^(.*)\._([^_]+)_.+", key)
46+
if m:
47+
reverted_key = f"{m.group(1)}.{m.group(2)}"
48+
else:
49+
reverted_key = key.split("_", 1)[0]
50+
assert reverted_key in original_tensor_names, (
51+
f"Reverted key {reverted_key} not found in original state_dict keys"
52+
)
53+
54+
4155
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
4256
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
4357
class TestSafeTensors(TestCase):
@@ -66,6 +80,9 @@ def test_safetensors(self, config, act_pre_scale=False):
6680

6781
with tempfile.NamedTemporaryFile() as f:
6882
tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
83+
84+
test_saved_tensor_names_format(tensors_data_dict, metadata)
85+
6986
save_file(tensors_data_dict, f.name, metadata=metadata)
7087
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
7188
reconstructed_dict = unflatten_tensor_state_dict(

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def unflatten_tensor_state_dict(
2424
2525
For example, given a previously flattened tensors_data_dict and metadata:
2626
tensors_data_dict = {
27-
'0.weight:qdata': torch.Tensor(...),
28-
'0.weight:scale': torch.Tensor(...),
29-
'0.bias:_data': torch.Tensor(...),
27+
'0._weight_qdata': torch.Tensor(...),
28+
'0._weight_scale': torch.Tensor(...),
29+
'0.bias': torch.Tensor(...),
3030
}
3131
metadata = {
3232
'0.weight': {
@@ -53,7 +53,7 @@ def unflatten_tensor_state_dict(
5353
}
5454
5555
Args:
56-
tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
56+
tensors_data_dict: a dictionary from "{tensor_name}_{tensor_data_attribute_name}" to flattened torch.Tensor data for tensor subclass instance
5757
metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
5858
5959
Returns:
@@ -68,23 +68,28 @@ def unflatten_tensor_state_dict(
6868
result = {}
6969

7070
for tensor_name in tensor_names:
71+
prefix = f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_"
7172
tensor_tensors = {}
7273
for key, value in combined_data.items():
73-
if key.startswith(f"{tensor_name}:"):
74+
if key.startswith(prefix):
7475
# Remove the prefix
75-
tensor_tensors[key[len(tensor_name) + 1 :]] = value
76+
tensor_tensors[key[len(prefix) :]] = value
7677

7778
tensor_metadata = json.loads(metadata.get(tensor_name))
7879
tensor_type = tensor_metadata.get("_type")
7980

8081
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
82+
if not tensor_tensors:
83+
# we allow the option of loading in state_dict info for a single tensor
84+
# if tensor state dict info is not loaded in yet, we wait for it to be provided
85+
# in a future call
86+
continue
8187
tensor_metadata["_data"].update(tensor_tensors)
8288
result[tensor_name] = object_from_dict(tensor_metadata)
8389
elif tensor_type == torch.Tensor.__name__:
84-
result[tensor_name] = tensor_tensors["_data"]
90+
result[tensor_name] = tensors_data_dict[tensor_name]
8591
else:
8692
raise ValueError(f"Unsupported tensor type: {tensor_type}")
87-
8893
return result
8994

9095

@@ -108,9 +113,9 @@ def flatten_tensor_state_dict(
108113
109114
We flatten this to:
110115
tensors_data = {
111-
'0.weight:qdata': torch.Tensor(...),
112-
'0.weight:scale': torch.Tensor(...),
113-
'0.bias:_data': torch.Tensor(...),
116+
'0._weight_qdata': torch.Tensor(...),
117+
'0._weight_scale': torch.Tensor(...),
118+
'0.bias': torch.Tensor(...),
114119
}
115120
metadata = {
116121
'0.weight': {
@@ -152,22 +157,23 @@ def flatten_tensor_state_dict(
152157
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
153158

154159
tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder)
160+
161+
# Clone tensors to avoid memory sharing issues
162+
tensors_dict_to_save = {
163+
f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_{key}": (
164+
value.detach().clone() if isinstance(value, torch.Tensor) else value
165+
)
166+
for key, value in tensor_dict.items()
167+
}
168+
155169
elif type(tensor) is torch.Tensor:
156-
tensor_dict = {"_data": tensor}
157170
tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
171+
tensors_dict_to_save = {tensor_name: tensor}
158172
else:
159173
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
160174

161-
# Clone tensors to avoid memory sharing issues
162-
prefixed_tensors_dict = {
163-
f"{tensor_name}:{key}": (
164-
value.detach().clone() if isinstance(value, torch.Tensor) else value
165-
)
166-
for key, value in tensor_dict.items()
167-
}
168-
169175
metadata[tensor_name] = tensor_metadata
170-
tensors_data_dict.update(prefixed_tensors_dict)
176+
tensors_data_dict.update(tensors_dict_to_save)
171177

172178
metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
173179
return tensors_data_dict, metadata

0 commit comments

Comments
 (0)