[Feature] Support float8 dtype storage. #9906

Open · wants to merge 3 commits into base: develop

Changes from 1 commit
20 changes: 14 additions & 6 deletions paddlenlp/utils/safetensors.py
@@ -17,13 +17,18 @@
import mmap
from collections import OrderedDict

import ml_dtypes
import numpy as np

__all__ = [
    "fast_safe_open",
    "fast_load_file",
]

# Expose the ml_dtypes types on the numpy namespace so the rest of this module
# can refer to them as np.bfloat16 / np.float8_*.
np.bfloat16 = ml_dtypes.bfloat16
np.float8_e5m2 = ml_dtypes.float8_e5m2
np.float8_e4m3fn = ml_dtypes.float8_e4m3fn


MAX_HEADER_SIZE = 100 * 1000 * 1000
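For context, a minimal sketch (not part of the diff) of what ml_dtypes provides: each type registers with numpy as an ordinary one- or two-byte scalar type, which is what lets the loader below treat fp8/bf16 buffers like any other dtype.

import ml_dtypes
import numpy as np

# The ml_dtypes types plug into numpy's dtype machinery like built-ins.
assert np.dtype(ml_dtypes.float8_e5m2).itemsize == 1
assert np.dtype(ml_dtypes.float8_e4m3fn).itemsize == 1
assert np.dtype(ml_dtypes.bfloat16).itemsize == 2

# Arrays can be created, cast, and byte-reinterpreted as usual.
x = np.array([1.0, 2.0, 3.5], dtype=ml_dtypes.float8_e5m2)
assert x.view(np.uint8).shape == (3,)  # raw one-byte payloads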

@@ -49,16 +54,16 @@
"BOOL": np.bool_,
"U8": np.uint8,
"I8": np.int8,
"F8_E5M2": 1, # no fp8
"F8_E4M3": 1, # no fp8
"F8_E5M2": np.float8_e5m2, # no fp8
"F8_E4M3": np.float8_e4m3fn, # no fp8
"I16": np.int16,
"U16": np.uint16,
"I32": np.int32,
"U32": np.uint32,
"I64": np.int64,
"U64": np.uint64,
"F16": np.float16,
"BF16": 2, # no bf16
"BF16": np.bfloat16, # no bf16
"F32": np.float32,
"F64": np.float64,
}
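A usage sketch of what the extended table enables (not part of the diff; it assumes the installed safetensors release accepts ml_dtypes arrays, which is what the tests below rely on): an fp8 tensor now round-trips through a .safetensors file.

import ml_dtypes
import numpy as np
from safetensors.numpy import load_file, save_file

# "F8_E4M3" in the file header now resolves to a real dtype rather than a
# placeholder, so the reference loader and fast_load_file can both decode it.
w = (np.random.random((4, 4)) * 2).astype(ml_dtypes.float8_e4m3fn)
save_file({"w": w}, "fp8_demo.safetensors", metadata={"format": "np"})
loaded = load_file("fp8_demo.safetensors")
np.testing.assert_equal(w, loaded["w"])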
@@ -238,9 +243,12 @@ def __getitem__(self, index):
        return tensor.reshape(target_shape)

    def get(self, *args, **kwargs):
-       tensor = np.empty(shape=self.shape, dtype=self.dtype)
-       self.bufferfile.seek(self.start_offset)
-       self.bufferfile.readinto(memoryview(tensor))
+       # readinto() via memoryview() does not accept the custom ml_dtypes
+       # types, so read the raw bytes and reinterpret with np.frombuffer.
+       nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
+       self.bufferfile.seek(self.start_offset)
+       buffer = self.bufferfile.read(nbytes)
+       tensor = np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape)
        return tensor

    @property
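The frombuffer path above can be exercised in isolation; a minimal sketch of the same byte-reinterpretation idea, assuming only numpy and ml_dtypes:

import ml_dtypes
import numpy as np

shape, dtype = (2, 3), ml_dtypes.float8_e5m2
src = (np.random.random(shape) * 4).astype(dtype)

# Simulate the mmap read: raw bytes in, typed array out.
raw = src.tobytes()
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
assert len(raw) == nbytes
out = np.frombuffer(raw, dtype=dtype).reshape(shape)
np.testing.assert_equal(out, src)  # frombuffer returns a read-only view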
1 change: 1 addition & 0 deletions requirements.txt
@@ -27,4 +27,5 @@ regex
numpy<=1.26.4
tiktoken
tokenizers>=0.21,<0.22
ml_dtypes
omegaconf
127 changes: 114 additions & 13 deletions tests/transformers/test_safetensors.py
@@ -17,17 +17,119 @@
import unittest

import numpy as np
import paddle
from safetensors.numpy import load_file, save_file

from paddlenlp.utils.safetensors import fast_load_file, fast_safe_open

from ..testing_utils import skip_platform

paddle.set_device("cpu")


def enhanced_to_tensor(tensor):
    # np.bfloat16 / np.float8_* come from the ml_dtypes patch applied in
    # paddlenlp.utils.safetensors. These arrays cannot be assigned to a Paddle
    # tensor directly, so reinterpret the raw bytes as an integer dtype and
    # alias the storage into a tensor of the target dtype.
    if tensor.dtype == np.bfloat16:
        return paddle.to_tensor(tensor.view(np.uint16))
    if tensor.dtype == np.float8_e5m2:
        t = paddle.to_tensor(tensor.view(np.int8))
        new_t = paddle.empty(t.shape, dtype=paddle.float8_e5m2)
        new_t.get_tensor()._share_data_with(t.get_tensor())
        return new_t
    if tensor.dtype == np.float8_e4m3fn:
        # simpler alternative that does not work for fp8:
        # return paddle.to_tensor(tensor.view(np.int8), dtype=paddle.float8_e4m3fn)
        t = paddle.to_tensor(tensor.view(np.int8))
        new_t = paddle.empty(t.shape, dtype=paddle.float8_e4m3fn)
        new_t.get_tensor()._share_data_with(t.get_tensor())
        return new_t
    return paddle.to_tensor(tensor)
Comment from the PR author (Collaborator): a trick to work around direct fp8 assignment.
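To illustrate the trick (a standalone sketch, assuming a Paddle build that defines the float8 dtypes): the fp8 bytes travel into Paddle as int8, and the storage is then aliased into an fp8-typed tensor, so no fp8 cast or assignment kernel is ever invoked.

import ml_dtypes
import numpy as np
import paddle

paddle.set_device("cpu")
src = np.array([0.5, 1.0, 1.5], dtype=ml_dtypes.float8_e4m3fn)

# Step 1: move the raw bytes as int8, a dtype Paddle can always assign.
t = paddle.to_tensor(src.view(np.int8))

# Step 2: alias that storage into an fp8-typed tensor; bytes are untouched.
fp8 = paddle.empty(t.shape, dtype=paddle.float8_e4m3fn)
fp8.get_tensor()._share_data_with(t.get_tensor())

assert fp8.dtype == paddle.float8_e4m3fn
assert tuple(fp8.shape) == src.shape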



class EextendDtypeNumpySafe(unittest.TestCase):
Comment from a reviewer (Contributor): typo, "Eextend" should be "Extend".

    def setUp(self):
        super().setUp()
        self.weight_map = {}
        self.tensors = [
            ([10, 1, 10], "float32"),
            ([1, 1, 10], "float32"),
            ([1, 1, 1, 10], "float32"),
            ([10, 10], "float32"),
            ([8], "float16"),
            ([5, 5, 5], "int32"),
        ]

    def get_target_dtype(self, dtype="float32"):
        count = 0
        weight_map = {}
        for shape, _ in self.tensors:
            weight_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
            count += 1
        return weight_map

    def get_paddle_target_dtype(self, dtype="float32"):
        weight_map = self.get_target_dtype(dtype)
        for k, v in list(weight_map.items()):
            weight_map[k] = enhanced_to_tensor(v)
        return weight_map

@skip_platform("win32", "cygwin")
def test_save_load_file_paddle(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for dtype in ["bfloat16", "float8_e5m2", "float8_e4m3fn"]:
weight_map = self.get_paddle_target_dtype(dtype)
path = os.path.join(tmpdirname, "test.safetensors")
shard = {}
for k in list(weight_map.keys()):
if isinstance(weight_map[k], paddle.Tensor):
shard[k] = weight_map[k].cpu().numpy()
else:
shard[k] = weight_map[k]

save_file(shard, path, metadata={"format": "np"})
sf_load = load_file(path)
fs_sf_load = fast_load_file(path)

for k, v in self.weight_map.items():
paddle.allclose(v, enhanced_to_tensor(sf_load[k]))
paddle.allclose(v, enhanced_to_tensor(fs_sf_load[k]))

@skip_platform("win32", "cygwin")
def test_save_load_file(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for dtype in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]:
weight_map = self.get_target_dtype(dtype)
path = os.path.join(tmpdirname, "test.safetensors")
save_file(weight_map, path, metadata={"format": "np"})
sf_load = load_file(path)
fs_sf_load = fast_load_file(path)
for k, v in self.weight_map.items():
np.testing.assert_equal(v, sf_load[k])
np.testing.assert_equal(v, fs_sf_load[k])

@skip_platform("win32", "cygwin")
def test_dtype_safe_open(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for dtype in ["float32", "int32", "bfloat16", "float8_e4m3fn", "float8_e5m2"]:
weight_map = self.get_target_dtype(dtype)
path = os.path.join(tmpdirname, "test.safetensors")
save_file(weight_map, path, metadata={"format": "np"})

with fast_safe_open(path, framework="np") as f:
for key in f.keys():
safe_slice = f.get_slice(key)
# np.testing.assert_equal(self.weight_map[key][2:1, ...], safe_slice[2:1, ...])
np.testing.assert_equal(weight_map[key][0, ...], safe_slice[0, ...])
np.testing.assert_equal(weight_map[key][0:1, ...], safe_slice[0:1, ...])
np.testing.assert_equal(weight_map[key][..., 2:], safe_slice[..., 2:])
np.testing.assert_equal(weight_map[key][..., 1], safe_slice[..., 1])
np.testing.assert_equal(weight_map[key][:2, ...], safe_slice[:2, ...])
np.testing.assert_equal(weight_map[key][..., :4], safe_slice[..., :4])


class FastSafetensors(unittest.TestCase):
    def setUp(self):
        super().setUp()
-       self.weigth_map = {}
+       self.weight_map = {}
        tensors = [
            ([10, 1, 10], "float32"),
            ([1, 1, 10], "float32"),
@@ -38,34 +140,33 @@ def setUp(self):
        ]
        count = 0
        for shape, dtype in tensors:
-           self.weigth_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
+           self.weight_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
            count += 1
-       print(self.weigth_map)

@skip_platform("win32", "cygwin")
def test_load_file(self):
with tempfile.TemporaryDirectory() as tmpdirname:
path = os.path.join(tmpdirname, "test.safetensors")
save_file(self.weigth_map, path, metadata={"format": "np"})
save_file(self.weight_map, path, metadata={"format": "np"})
sf_load = load_file(path)
fs_sf_load = fast_load_file(path)
for k, v in self.weigth_map.items():
for k, v in self.weight_map.items():
np.testing.assert_equal(v, sf_load[k])
np.testing.assert_equal(v, fs_sf_load[k])

@skip_platform("win32", "cygwin")
def test_safe_open(self):
with tempfile.TemporaryDirectory() as tmpdirname:
path = os.path.join(tmpdirname, "test.safetensors")
save_file(self.weigth_map, path, metadata={"format": "np"})
save_file(self.weight_map, path, metadata={"format": "np"})

with fast_safe_open(path, framework="np") as f:
for key in f.keys():
safe_slice = f.get_slice(key)
# np.testing.assert_equal(self.weigth_map[key][2:1, ...], safe_slice[2:1, ...])
np.testing.assert_equal(self.weigth_map[key][0, ...], safe_slice[0, ...])
np.testing.assert_equal(self.weigth_map[key][0:1, ...], safe_slice[0:1, ...])
np.testing.assert_equal(self.weigth_map[key][..., 2:], safe_slice[..., 2:])
np.testing.assert_equal(self.weigth_map[key][..., 1], safe_slice[..., 1])
np.testing.assert_equal(self.weigth_map[key][:2, ...], safe_slice[:2, ...])
np.testing.assert_equal(self.weigth_map[key][..., :4], safe_slice[..., :4])
# np.testing.assert_equal(self.weight_map[key][2:1, ...], safe_slice[2:1, ...])
np.testing.assert_equal(self.weight_map[key][0, ...], safe_slice[0, ...])
np.testing.assert_equal(self.weight_map[key][0:1, ...], safe_slice[0:1, ...])
np.testing.assert_equal(self.weight_map[key][..., 2:], safe_slice[..., 2:])
np.testing.assert_equal(self.weight_map[key][..., 1], safe_slice[..., 1])
np.testing.assert_equal(self.weight_map[key][:2, ...], safe_slice[:2, ...])
np.testing.assert_equal(self.weight_map[key][..., :4], safe_slice[..., :4])