Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/prototype/safetensors/test_safetensors_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao import quantize_
from torchao.prototype.safetensors.safetensors_support import (
load_tensor_state_dict,
save_tensor_state_dict,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.utils import (
is_sm_at_least_89,
)


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestSafeTensors(TestCase):
    def _build_model(self):
        # Fresh single-layer bf16 linear model on CUDA (used both for the
        # quantized reference and the reload target).
        return torch.nn.Sequential(
            torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
        )

    def test_safetensors(self):
        """Round-trip a float8-quantized state dict through safetensors and
        check the reloaded model reproduces the reference outputs exactly."""
        quantized = self._build_model()
        quantize_(
            quantized,
            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        )
        inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
        expected = quantized(*inputs)

        # Save to a temporary safetensors file and immediately load it back.
        with tempfile.NamedTemporaryFile() as tmp:
            save_tensor_state_dict(quantized.state_dict(), tmp.name)
            restored_state = load_tensor_state_dict(tmp.name, device="cuda")

        # A fresh (unquantized) model adopts the reconstructed quantized
        # tensors via assign=True.
        fresh = self._build_model()
        fresh.load_state_dict(restored_state, assign=True)
        assert torch.equal(fresh(*inputs), expected)


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
Empty file.
161 changes: 161 additions & 0 deletions torchao/prototype/safetensors/safetensors_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import dataclasses
import enum
import json
from typing import Any, Dict

import torch

import torchao
from torchao.quantization import Float8Tensor
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
Comment on lines +9 to +11
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like some imports are not used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are used to create the ALLOWED_CLASSES dict


# Allowlist of classes that deserialization may reconstruct.
# `object_from_dict` refuses any "_type" not registered here, so metadata
# read from a file cannot instantiate arbitrary classes.
ALLOWED_CLASSES = {
    "Float8Tensor": Float8Tensor,
    "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
    "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
    "PerRow": torchao.quantization.PerRow,
    "PerTensor": torchao.quantization.PerTensor,
    "KernelPreference": KernelPreference,
}


class Float8TensorAttributeJSONEncoder(json.JSONEncoder):
    """JSON encoder for Float8Tensor attribute metadata.

    Encodes Float8Tensor attributes, NamedTuples, dataclasses, torch dtypes
    and enums as ``{"_type": <class name>, "_data": <payload>}`` dictionaries
    that `object_from_dict` can later reconstruct.
    """

    def default(self, o):
        if isinstance(o, Float8Tensor):
            # Serialize every (optional and required) tensor attribute.
            attr_names = o.optional_tensor_attribute_names + o.tensor_attribute_names
            encoded_attrs = {
                name: self.encode_value(getattr(o, name)) for name in attr_names
            }
            return {"_type": o.__class__.__name__, "_data": encoded_attrs}

        # NamedTuple characteristics: both _fields and _asdict are present.
        if hasattr(o, "_fields") and hasattr(o, "_asdict"):
            encoded_fields = {
                k: self.encode_value(v) for k, v in o._asdict().items()
            }
            return {"_type": o.__class__.__name__, "_data": encoded_fields}

        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            encoded_fields = {
                f.name: self.encode_value(getattr(o, f.name))
                for f in dataclasses.fields(o)
            }
            return {"_type": o.__class__.__name__, "_data": encoded_fields}

        if isinstance(o, torch.dtype):
            # Store only the short name, e.g. "bfloat16".
            return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]}

        if isinstance(o, enum.Enum):
            # Store the class name for enums to ensure uniqueness.
            return {"_type": f"{o.__class__.__name__}", "_data": o.name}

        if isinstance(o, list):
            return [self.encode_value(item) for item in o]

        if isinstance(o, dict):
            return {k: self.encode_value(v) for k, v in o.items()}

        # Anything else: defer to the base class (raises TypeError).
        return super().default(o)

    def encode_value(self, value):
        """Recursively encode *value*; pass plain JSON types through as-is."""
        try:
            # default() handles all our special cases and raises TypeError
            # for types it does not recognize.
            return self.default(value)
        except TypeError:
            # Plain JSON-serializable value — the standard encoder will
            # process it later.
            return value


def object_from_dict(data: Dict[str, Any]):
    """Reconstruct an object from its ``{"_type": ..., "_data": ...}`` encoding.

    Inverse of `Float8TensorAttributeJSONEncoder`: only torch dtypes and
    classes registered in ``ALLOWED_CLASSES`` may be instantiated, so
    metadata read from a file cannot construct arbitrary types.

    Args:
        data: Dictionary with "_type" (class name) and "_data" (payload).

    Returns:
        The reconstructed object.

    Raises:
        TypeError: if `data` is not a dictionary.
        ValueError: if required keys are missing, the type is not in the
            allowlist, or instantiation of the target class fails.
        NotImplementedError: if a tuple value is encountered (tuples do not
            survive a JSON round trip; lists are required instead).
    """
    if not isinstance(data, dict):
        raise TypeError(f"Expected dictionary, got {type(data)}")

    if "_type" not in data or "_data" not in data:
        raise ValueError("Input dictionary missing required '_type' or '_data' fields")

    type_path = data["_type"]
    obj_data = data["_data"]

    # torch dtypes are encoded by short name ("float32", "bfloat16", ...).
    if type_path == "torch.dtype":
        return getattr(torch, obj_data)

    cls = ALLOWED_CLASSES.get(type_path)

    # If we couldn't find the class in any allowed module, raise an error
    if cls is None:
        allowed_modules_str = ", ".join(ALLOWED_CLASSES)
        raise ValueError(
            f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
        )

    # Handle the case where obj_data is not a dictionary
    if not isinstance(obj_data, dict):
        if issubclass(cls, enum.Enum):
            # For enums, convert the stored member name back to the member.
            return getattr(cls, obj_data)
        # For other primitive payloads, try constructing an instance from
        # the value; fall back to the raw value if the constructor rejects it.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt etc.)
        try:
            return cls(obj_data)
        except (TypeError, ValueError):
            return obj_data

    def _maybe_decode(value):
        # Recursively decode a value when it is itself an encoded object.
        if isinstance(value, dict) and "_type" in value and "_data" in value:
            return object_from_dict(value)
        return value

    processed_data = {}
    for key, value in obj_data.items():
        if isinstance(value, dict) and "_type" in value and "_data" in value:
            # Recursively handle nested configs
            processed_data[key] = object_from_dict(value)
        elif isinstance(value, list):
            # Handle lists of possible configs
            processed_data[key] = [_maybe_decode(item) for item in value]
        elif isinstance(value, tuple):
            raise NotImplementedError(
                "Tuples will be serialized as List in JSON, so we recommend to use "
                f"Lists instead to avoid surprises. got: {value}"
            )
        elif isinstance(value, dict):
            # Handle dicts of possible configs
            processed_data[key] = {
                k: _maybe_decode(v) for k, v in value.items()
            }
        else:
            processed_data[key] = value

    # Create and return the instance; chain the cause for easier debugging.
    try:
        return cls(**processed_data)
    except Exception as e:
        raise ValueError(f"Failed to create instance of {cls.__name__}: {e}") from e
143 changes: 143 additions & 0 deletions torchao/prototype/safetensors/safetensors_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import json
import logging
from typing import Dict

import torch
from safetensors.torch import load_file, save_file

from torchao.prototype.safetensors.safetensors_serialization import (
Float8TensorAttributeJSONEncoder,
object_from_dict,
)
from torchao.quantization import Float8Tensor

logger: logging.Logger = logging.getLogger(__name__)


def load_tensor_state_dict(file_path: str, device: str):
    """
    Load a dictionary of tensor subclasses from a safetensors file.

    For torch.Tensors, we load:
        - _data: the tensor data
        - _type: the tensor type

    For Float8Tensor, we load:
        - tensor_data: qdata and scale
        - tensor_attributes:
            - block_size
            - mm_config
            - hp_value_lb
            - hp_value_ub
            - act_quant_kwargs
            - kernel_preference
            - dtype

    Args:
        file_path: Path to the safetensors file
        device: Device to load the tensors onto (e.g. "cuda" or "cpu")

    Returns:
        Dictionary of reconstructed tensor subclasses

    Raises:
        ValueError: if the file carries no tensor-name metadata, a tensor's
            metadata entry is missing, or a tensor type is unsupported.
    """
    # Hoisted from the middle of the function: keep imports at the top.
    import struct

    loaded_tensors = load_file(file_path, device)

    # safetensors layout: an 8-byte little-endian header length, then a JSON
    # header whose "__metadata__" entry carries our per-tensor metadata.
    with open(file_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_bytes = f.read(header_size)
        header = json.loads(header_bytes)
        metadata = header.get("__metadata__", {})

    if "tensor_names" not in metadata:
        raise ValueError("No tensors found")

    tensor_names = json.loads(metadata["tensor_names"])
    result = {}

    for tensor_name in tensor_names:
        # Collect this tensor's stored components, stripping the
        # "<tensor_name>:" prefix added at save time.
        prefix = f"{tensor_name}:"
        tensor_tensors = {
            key[len(prefix):]: value
            for key, value in loaded_tensors.items()
            if key.startswith(prefix)
        }

        raw_metadata = metadata.get(tensor_name)
        if raw_metadata is None:
            # Previously fell through to json.loads(None) -> cryptic TypeError.
            raise ValueError(f"Missing metadata for tensor: {tensor_name}")
        tensor_metadata = json.loads(raw_metadata)
        tensor_type = tensor_metadata.get("_type")

        if tensor_type == Float8Tensor.__name__:
            # Merge the loaded tensor components into the attribute dict so
            # object_from_dict can reconstruct the Float8Tensor.
            tensor_metadata["_data"].update(tensor_tensors)
            result[tensor_name] = object_from_dict(tensor_metadata)
        elif tensor_type == torch.Tensor.__name__:
            result[tensor_name] = tensor_tensors["_data"]
        else:
            raise ValueError(f"Unsupported tensor type: {tensor_type}")

    logger.info(
        f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
    )
    return result


def save_tensor_state_dict(
    tensor_dict: Dict[str, Dict[str, torch.Tensor]],
    file_path: str,
):
    """
    Save a dictionary of tensor subclasses with appropriate metadata.

    For torch.Tensors, we save:
        - _data: the tensor data
        - _type: the tensor type

    For Float8Tensor, we save:
        - tensor_data: qdata and scale
        - tensor_attributes: block_size, mm_config, hp_value_lb, hp_value_ub,
          act_quant_kwargs, kernel_preference, dtype

    Args:
        tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
        file_path: Path where to save the tensors
    """
    all_metadata = {}
    all_tensors = {}

    for name, value in tensor_dict.items():
        if isinstance(value, Float8Tensor):
            # One entry per tensor-data component (qdata, scale, ...).
            components = {
                data_name: getattr(value, data_name)
                for data_name in value.tensor_data_names
            }
            meta = json.dumps(value, cls=Float8TensorAttributeJSONEncoder)
        elif type(value) is torch.Tensor:
            components = {"_data": value}
            meta = json.dumps({"_type": torch.Tensor.__name__})
        else:
            raise ValueError(f"Unsupported tensor type: {type(value)}")

        all_metadata[name] = meta
        # Prefix each component key with the tensor name so components from
        # different tensors cannot collide; clone tensors to avoid memory
        # sharing issues when safetensors writes the file.
        for key, component in components.items():
            if isinstance(component, torch.Tensor):
                component = component.detach().clone()
            all_tensors[f"{name}:{key}"] = component

    all_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))

    save_file(all_tensors, file_path, metadata=all_metadata)
    logger.info(
        f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
    )
Loading