Commit 1c695a1

updating flatten/unflatten functions
1 parent 2a89491 commit 1c695a1

File tree: 4 files changed, +35 −22 lines

benchmarks/benchmark_uintx.py

Lines changed: 1 addition & 1 deletion

@@ -6,11 +6,11 @@
 from copy import deepcopy
 
 import torch
-
 from torchao.prototype.uintx import (
     uintx_affine_weight_only,
     unpack_cpu,
 )
+
 from torchao.quantization.quant_api import quantize_
test/prototype/safetensors/test_safetensors_support.py

Lines changed: 6 additions & 0 deletions

@@ -66,6 +66,12 @@ def test_safetensors(self, config, act_pre_scale=False):
 
         with tempfile.NamedTemporaryFile() as f:
             tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
+
+            for key in tensors_data_dict.keys():
+                assert key.startswith("0._weight_") or key.startswith("0.bias"), (
+                    f"Unexpected key format: {key}"
+                )
+
             save_file(tensors_data_dict, f.name, metadata=metadata)
             tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
             reconstructed_dict = unflatten_tensor_state_dict(
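
As a quick illustration of the key layout this new assertion enforces, here is a minimal sketch. It assumes a one-layer model whose weight at module "0" is a tensor subclass with qdata and scale attributes and whose bias is a plain torch.Tensor; the shapes and dtypes are illustrative, not taken from the test.

import torch

# Hypothetical flattened state dict: subclass attributes of "0.weight" become
# "0._weight_<attr>" keys, while the plain bias keeps its own "0.bias" key.
tensors_data_dict = {
    "0._weight_qdata": torch.zeros(64, 64, dtype=torch.int8),  # packed quantized data
    "0._weight_scale": torch.ones(64),  # quantization scales
    "0.bias": torch.zeros(64),  # plain tensor, stored under its original name
}

# The same check the test performs above.
for key in tensors_data_dict.keys():
    assert key.startswith("0._weight_") or key.startswith("0.bias"), key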

test/test_low_bit_optim.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@
 common_utils.SEED = 1234
 
 from packaging.version import Version
+
 from torchao import optim
 from torchao.optim.quant_utils import (
     _fp32_to_bf16_sr,

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 27 additions & 21 deletions

@@ -24,9 +24,9 @@ def unflatten_tensor_state_dict(
 
     For example, given a previously flattened tensors_data_dict and metadata:
     tensors_data_dict = {
-        '0.weight:qdata': torch.Tensor(...),
-        '0.weight:scale': torch.Tensor(...),
-        '0.bias:_data': torch.Tensor(...),
+        '0._weight_qdata': torch.Tensor(...),
+        '0._weight_scale': torch.Tensor(...),
+        '0.bias': torch.Tensor(...),
     }
     metadata = {
         '0.weight': {
@@ -53,7 +53,7 @@ def unflatten_tensor_state_dict(
     }
 
     Args:
-        tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
+        tensors_data_dict: a dictionary from "{tensor_name}_{tensor_data_attribute_name}" to flattened torch.Tensor data for tensor subclass instance
         metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
 
     Returns:
@@ -68,23 +68,28 @@ def unflatten_tensor_state_dict(
     result = {}
 
     for tensor_name in tensor_names:
+        prefix = f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_"
         tensor_tensors = {}
         for key, value in combined_data.items():
-            if key.startswith(f"{tensor_name}:"):
+            if key.startswith(prefix):
                 # Remove the prefix
-                tensor_tensors[key[len(tensor_name) + 1 :]] = value
+                tensor_tensors[key[len(prefix) :]] = value
 
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")
 
         if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
+            if not tensor_tensors:
+                # we allow the option of loading in state_dict info for a single tensor
+                # if tensor state dict info is not loaded in yet, we wait for it to be provided
+                # in a future call
+                continue
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
-            result[tensor_name] = tensor_tensors["_data"]
+            result[tensor_name] = tensors_data_dict[tensor_name]
         else:
            raise ValueError(f"Unsupported tensor type: {tensor_type}")
-
     return result
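To make the new prefix handling concrete, here is a small standalone sketch of the rsplit logic from the hunk above; the sample keys and values are made up for illustration.

# For a fully qualified tensor name like "0.weight", the flattened keys now
# look like "0._weight_qdata", so the prefix is "<module>._<attr>_".
tensor_name = "0.weight"
module_path, attr_name = tensor_name.rsplit(".", 1)
prefix = f"{module_path}._{attr_name}_"  # "0._weight_"

combined_data = {"0._weight_qdata": "qdata tensor", "0._weight_scale": "scale tensor"}

# Strip the prefix to recover attribute names, mirroring the inner loop above.
tensor_tensors = {
    key[len(prefix):]: value
    for key, value in combined_data.items()
    if key.startswith(prefix)
}
assert tensor_tensors == {"qdata": "qdata tensor", "scale": "scale tensor"}

A consequence visible in the diff: a plain torch.Tensor no longer gets a ":_data" suffix, so unflattening reads it straight out of tensors_data_dict under its original name.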

@@ -108,9 +113,9 @@ def flatten_tensor_state_dict(
 
     We flatten this to:
     tensors_data = {
-        '0.weight:qdata': torch.Tensor(...),
-        '0.weight:scale': torch.Tensor(...),
-        '0.bias:_data': torch.Tensor(...),
+        '0._weight_qdata': torch.Tensor(...),
+        '0._weight_scale': torch.Tensor(...),
+        '0.bias': torch.Tensor(...),
     }
     metadata = {
         '0.weight': {
@@ -152,22 +157,23 @@ def flatten_tensor_state_dict(
                 tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
 
             tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder)
+
+            # Clone tensors to avoid memory sharing issues
+            tensors_dict_to_save = {
+                f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_{key}": (
+                    value.detach().clone() if isinstance(value, torch.Tensor) else value
+                )
+                for key, value in tensor_dict.items()
+            }
+
         elif type(tensor) is torch.Tensor:
-            tensor_dict = {"_data": tensor}
             tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
+            tensors_dict_to_save = {tensor_name: tensor}
         else:
             raise ValueError(f"Unsupported tensor type: {type(tensor)}")
 
-        # Clone tensors to avoid memory sharing issues
-        prefixed_tensors_dict = {
-            f"{tensor_name}:{key}": (
-                value.detach().clone() if isinstance(value, torch.Tensor) else value
-            )
-            for key, value in tensor_dict.items()
-        }
-
         metadata[tensor_name] = tensor_metadata
-        tensors_data_dict.update(prefixed_tensors_dict)
+        tensors_data_dict.update(tensors_dict_to_save)
 
     metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
     return tensors_data_dict, metadata
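
To show the relocated clone step in isolation, here is a minimal sketch of the new comprehension; tensor_dict and its shapes are placeholders rather than values produced by the library.

import torch

tensor_name = "0.weight"
tensor_dict = {
    "qdata": torch.zeros(4, 4, dtype=torch.int8),
    "scale": torch.ones(4),
}

# Each flattened key is "<module>._<attr>_<field>"; tensor values are detached
# and cloned so the saved file does not share storage with the live model.
tensors_dict_to_save = {
    f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_{key}": (
        value.detach().clone() if isinstance(value, torch.Tensor) else value
    )
    for key, value in tensor_dict.items()
}

assert set(tensors_dict_to_save) == {"0._weight_qdata", "0._weight_scale"}
assert tensors_dict_to_save["0._weight_scale"].data_ptr() != tensor_dict["scale"].data_ptr()

As the test above exercises, the full round trip is flatten_tensor_state_dict, save_file, load_data, then unflatten_tensor_state_dict.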
