-
Notifications
You must be signed in to change notification settings - Fork 454
safetensors support #2881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
safetensors support #2881
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
dbe8e4e
initial commit
liangel-02 2389675
safe_serialization=True
liangel-02 d0f7d53
json
liangel-02 57106e9
save and load
liangel-02 66a6697
load json
liangel-02 7252a33
remove binary files
liangel-02 54cd56a
delete debug
liangel-02 499da7e
comment
liangel-02 dfe0e13
refactor
liangel-02 a707c91
remove binary
liangel-02 75e420d
ruff
liangel-02 756e918
addressing comments
liangel-02 26b96a9
format
liangel-02 54e433e
remove
liangel-02 abce5eb
address comments
liangel-02 4d92717
clean up code
liangel-02 4f91503
fix comments
liangel-02 6397ed6
format
liangel-02 a9d960a
tempfile
liangel-02 9824c96
logging
liangel-02 c4e9165
addressing comments
liangel-02 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import tempfile | ||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import ( | ||
| TestCase, | ||
| run_tests, | ||
| ) | ||
|
|
||
| from torchao import quantize_ | ||
| from torchao.prototype.safetensors.safetensors_support import ( | ||
| load_tensor_state_dict, | ||
| save_tensor_state_dict, | ||
| ) | ||
| from torchao.quantization.granularity import PerRow | ||
| from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig | ||
| from torchao.utils import ( | ||
| is_sm_at_least_89, | ||
| ) | ||
|
|
||
|
|
||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestSafeTensors(TestCase):
    def test_safetensors(self):
        """Round-trip a float8-quantized linear through safetensors save/load
        and check the reloaded model reproduces the original outputs."""

        def make_model():
            # Fresh bf16 linear on CUDA; quantization is applied separately.
            return torch.nn.Sequential(
                torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
            )

        quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
        model = make_model()
        quantize_(model, quant_config)
        example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
        ref_output = model(*example_inputs)

        with tempfile.NamedTemporaryFile() as tmp:
            save_tensor_state_dict(model.state_dict(), tmp.name)
            reconstructed_dict = load_tensor_state_dict(tmp.name, device="cuda")

        # assign=True swaps the loaded tensor subclasses directly into the
        # fresh module instead of copying them into its bf16 parameters.
        model = make_model()
        model.load_state_dict(reconstructed_dict, assign=True)
        output = model(*example_inputs)
        assert torch.equal(output, ref_output)
|
|
||
|
|
||
# Entry point: run the tests via torch's common test runner.
if __name__ == "__main__":
    run_tests()
Empty file.
161 changes: 161 additions & 0 deletions
161
torchao/prototype/safetensors/safetensors_serialization.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| import dataclasses | ||
| import enum | ||
| import json | ||
| from typing import Any, Dict | ||
|
|
||
| import torch | ||
|
|
||
| import torchao | ||
| from torchao.quantization import Float8Tensor | ||
| from torchao.quantization.quantize_.common import KernelPreference | ||
| from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs | ||
|
|
||
# Whitelist of classes that object_from_dict is allowed to instantiate when
# deserializing tensor metadata from a safetensors header. Restricting
# reconstruction to this closed set keeps loading untrusted files safe.
# Keys are the bare class names produced by Float8TensorAttributeJSONEncoder.
ALLOWED_CLASSES = {
    "Float8Tensor": Float8Tensor,
    "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
    "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
    "PerRow": torchao.quantization.PerRow,
    "PerTensor": torchao.quantization.PerTensor,
    "KernelPreference": KernelPreference,
}
|
|
||
|
|
||
class Float8TensorAttributeJSONEncoder(json.JSONEncoder):
    """JSON encoder for Float8Tensor metadata.

    Serializes a Float8Tensor's non-tensor attributes — plus nested
    NamedTuples, dataclasses, torch dtypes and enums — into
    ``{"_type": <class name>, "_data": <payload>}`` dictionaries that
    ``object_from_dict`` can later reconstruct.
    """

    def default(self, o):
        # Float8Tensor: encode every tensor attribute (required + optional)
        # by name into a flat dict.
        if isinstance(o, Float8Tensor):
            tensor_attr_dict = {}
            all_tensor_attributes = (
                o.optional_tensor_attribute_names + o.tensor_attribute_names
            )

            for tensor_attribute_name in all_tensor_attributes:
                attribute = getattr(o, tensor_attribute_name)
                encoded_attribute = self.encode_value(attribute)
                tensor_attr_dict[tensor_attribute_name] = encoded_attribute

            return {"_type": o.__class__.__name__, "_data": tensor_attr_dict}

        if hasattr(o, "_fields") and hasattr(
            o, "_asdict"
        ):  # Check for NamedTuple characteristics
            asdict_data = o._asdict()
            # Process each field to handle nested objects
            processed_data = {k: self.encode_value(v) for k, v in asdict_data.items()}

            return {
                "_type": o.__class__.__name__,
                "_data": processed_data,
            }

        # Dataclass instances (not dataclass *types*, hence the type check)
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            data_dict = {}
            # Process each field to handle nested objects
            for f in dataclasses.fields(o):
                data_dict[f.name] = self.encode_value(getattr(o, f.name))

            return {
                "_type": o.__class__.__name__,
                "_data": data_dict,
            }

        # torch dtypes round-trip by short name, e.g. torch.bfloat16 -> "bfloat16"
        if isinstance(o, torch.dtype):
            return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]}

        if isinstance(o, enum.Enum):
            # NOTE(review): only the bare class name is stored, not the full
            # module path — sufficient only while ALLOWED_CLASSES keys stay
            # unique across modules.
            return {"_type": f"{o.__class__.__name__}", "_data": o.name}

        if isinstance(o, list):
            return [self.encode_value(item) for item in o]

        if isinstance(o, dict):
            return {k: self.encode_value(v) for k, v in o.items()}

        # Default case: json.JSONEncoder raises TypeError for anything it
        # cannot serialize, which encode_value treats as "pass through".
        return super().default(o)

    def encode_value(self, value):
        """Helper method to recursively encode a value"""
        # Try to use default for custom type
        try:
            # This will handle all our special cases and raise TypeError
            # if it can't handle the type
            result = self.default(value)
            return result
        except TypeError:
            pass

        # Default case - return as is
        # (This will be processed by standard JSON encoder later)
        return value
|
|
||
|
|
||
def object_from_dict(data: Dict[str, Any]):
    """Reconstruct an object from a ``{"_type": ..., "_data": ...}`` dict.

    Inverse of Float8TensorAttributeJSONEncoder. Only classes listed in
    ALLOWED_CLASSES (plus torch dtypes) may be instantiated, which keeps
    deserialization of untrusted metadata safe.

    Args:
        data: Dictionary with "_type" (class name) and "_data" (payload).

    Returns:
        The reconstructed object.

    Raises:
        TypeError: if ``data`` is not a dictionary.
        ValueError: if required keys are missing, the type is not in
            ALLOWED_CLASSES, or instantiation of the class fails.
        NotImplementedError: if a tuple value is encountered in the payload.
    """
    if not isinstance(data, dict):
        raise TypeError(f"Expected dictionary, got {type(data)}")

    if "_type" not in data or "_data" not in data:
        raise ValueError("Input dictionary missing required '_type' or '_data' fields")

    type_path = data["_type"]
    obj_data = data["_data"]

    # torch dtypes are serialized by short name (e.g. "bfloat16"), not class
    if type_path == "torch.dtype":
        return getattr(torch, obj_data)

    cls = ALLOWED_CLASSES.get(type_path)

    # If we couldn't find the class in any allowed module, raise an error
    if cls is None:
        allowed_modules_str = ", ".join(ALLOWED_CLASSES)
        raise ValueError(
            f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
        )

    # Handle the case where obj_data is not a dictionary
    if not isinstance(obj_data, dict):
        if issubclass(cls, enum.Enum):
            # For enums, convert the stored member name back to the member
            return getattr(cls, obj_data)
        # For other primitive payloads, try constructing an instance with the
        # value. Catch only the constructor-failure exceptions a best-effort
        # conversion can produce (was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit).
        try:
            return cls(obj_data)
        except (TypeError, ValueError):
            return obj_data

    processed_data = {}

    for key, value in obj_data.items():
        if isinstance(value, dict) and "_type" in value and "_data" in value:
            # Recursively handle nested configs
            processed_data[key] = object_from_dict(value)
        elif isinstance(value, list):
            # Handle lists of possible configs
            processed_data[key] = [
                object_from_dict(item)
                if isinstance(item, dict) and "_type" in item and "_data" in item
                else item
                for item in value
            ]
        elif isinstance(value, tuple):
            raise NotImplementedError(
                "Tuples will be serialized as List in JSON, so we recommend to use "
                f"Lists instead to avoid surprises. got: {value}"
            )
        elif isinstance(value, dict):
            # Handle plain dicts whose values may themselves be configs
            processed_data[key] = {
                k: object_from_dict(v)
                if isinstance(v, dict) and "_type" in v and "_data" in v
                else v
                for k, v in value.items()
            }
        else:
            processed_data[key] = value

    # Create and return the instance; chain the original error for debugging
    try:
        return cls(**processed_data)
    except Exception as e:
        raise ValueError(f"Failed to create instance of {cls.__name__}: {e}") from e
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| import json | ||
| import logging | ||
| from typing import Dict | ||
|
|
||
| import torch | ||
| from safetensors.torch import load_file, save_file | ||
|
|
||
| from torchao.prototype.safetensors.safetensors_serialization import ( | ||
| Float8TensorAttributeJSONEncoder, | ||
| object_from_dict, | ||
| ) | ||
| from torchao.quantization import Float8Tensor | ||
|
|
||
| logger: logging.Logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def load_tensor_state_dict(file_path: str, device: str):
    """
    Load a dictionary of tensor subclasses from a safetensors file.

    For torch.Tensors, we load:
        - _data: the tensor data
        - _type: the tensor type

    For Float8Tensor, we load:
        - tensor_data: qdata and scale
        - tensor_attributes:
            - block_size
            - mm_config
            - hp_value_lb
            - hp_value_ub
            - act_quant_kwargs
            - kernel_preference
            - dtype

    Args:
        file_path: Path to the safetensors file
        device: Device to place the loaded tensors on (e.g. "cuda", "cpu")

    Returns:
        Dictionary of reconstructed tensor subclasses

    Raises:
        ValueError: if the file carries no tensor metadata, per-tensor
            metadata is missing, or a tensor type is unsupported.
    """
    import struct  # only needed here, to parse the safetensors header

    loaded_tensors = load_file(file_path, device)

    # safetensors layout: little-endian u64 header length, then a JSON
    # header whose optional "__metadata__" dict holds our string metadata.
    with open(file_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_bytes = f.read(header_size)
        header = json.loads(header_bytes)
        metadata = header.get("__metadata__", {})

    if "tensor_names" not in metadata:
        raise ValueError("No tensors found")

    tensor_names = json.loads(metadata["tensor_names"])
    result = {}

    for tensor_name in tensor_names:
        # Collect this tensor's components, stripping the "name:" prefix
        # that save_tensor_state_dict added.
        tensor_tensors = {}
        for key, value in loaded_tensors.items():
            if key.startswith(f"{tensor_name}:"):
                # Remove the prefix
                tensor_tensors[key[len(tensor_name) + 1 :]] = value

        raw_metadata = metadata.get(tensor_name)
        if raw_metadata is None:
            # Previously json.loads(None) raised an opaque TypeError here;
            # fail with an explicit, actionable error instead.
            raise ValueError(f"Missing metadata for tensor: {tensor_name}")
        tensor_metadata = json.loads(raw_metadata)
        tensor_type = tensor_metadata.get("_type")

        if tensor_type == Float8Tensor.__name__:
            # Merge the loaded tensor components into the attribute payload
            # before reconstructing the subclass.
            tensor_metadata["_data"].update(tensor_tensors)
            result[tensor_name] = object_from_dict(tensor_metadata)
        elif tensor_type == torch.Tensor.__name__:
            result[tensor_name] = tensor_tensors["_data"]
        else:
            raise ValueError(f"Unsupported tensor type: {tensor_type}")

    logger.info(
        f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
    )
    return result
|
|
||
|
|
||
def save_tensor_state_dict(
    # Values are single tensors (Float8Tensor or plain torch.Tensor), not
    # nested dicts — the previous Dict[str, Dict[str, torch.Tensor]] hint
    # contradicted the body below.
    tensor_dict: Dict[str, torch.Tensor],
    file_path: str,
):
    """
    Save a dictionary of tensor subclasses with appropriate metadata.

    For torch.Tensors, we save:
        - _data: the tensor data
        - _type: the tensor type

    For Float8Tensor, we save:
        - tensor_data:
            - qdata
            - scale
        - tensor_attributes:
            - block_size
            - mm_config
            - hp_value_lb
            - hp_value_ub
            - act_quant_kwargs
            - kernel_preference
            - dtype

    Args:
        tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
        file_path: Path where to save the tensors

    Raises:
        ValueError: if a value is neither a Float8Tensor nor a plain
            torch.Tensor.
    """

    combined_metadata = {}
    combined_tensors_dict = {}

    for tensor_name, tensor in tensor_dict.items():
        if isinstance(tensor, Float8Tensor):
            tensors_dict = {}
            for tensor_data_name in tensor.tensor_data_names:
                tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

            metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
        elif type(tensor) is torch.Tensor:
            # Exact-type check on purpose: tensor subclasses other than
            # Float8Tensor are unsupported and must hit the error below.
            tensors_dict = {"_data": tensor}
            metadata = json.dumps({"_type": torch.Tensor.__name__})
        else:
            raise ValueError(f"Unsupported tensor type: {type(tensor)}")

        # Clone tensors to avoid memory sharing issues
        prefixed_tensors_dict = {
            f"{tensor_name}:{key}": (
                value.detach().clone() if isinstance(value, torch.Tensor) else value
            )
            for key, value in tensors_dict.items()
        }

        combined_metadata[tensor_name] = metadata
        combined_tensors_dict.update(prefixed_tensors_dict)

    # Record the original tensor names so load_tensor_state_dict can
    # regroup the prefixed entries.
    combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))

    save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
    logger.info(
        f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
    )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like some imports are not used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are used to create the ALLOWED_CLASSES dict