support sft mapdataset #8840

Merged: 5 commits, Aug 5, 2024. Changes from all commits.
paddlenlp/data/indexed_dataset.py: 272 additions and 0 deletions
@@ -29,6 +29,7 @@
import shutil
import struct
import time
from dataclasses import fields
from functools import lru_cache
from itertools import accumulate

@@ -68,6 +69,19 @@ def make_dataset(path, impl, skip_warmup=False):
return None


def make_sft_dataset(path, dataclass, skip_warmup=False, impl="mmap"):
if impl != "mmap":
raise ValueError("SFT Indexed Dataset only support mmap memory-mapped method temporarily")

print_rank_0(" > building dataset index ...")
start_time = time.time()
sft_indexed_dataset = SFTMMapIndexedDataset(path, dataclass, skip_warmup)
print_rank_0(" > finished creating SFT indexed dataset in {:4f} " "seconds".format(time.time() - start_time))
print_rank_0(" number of samples: {}".format(len(sft_indexed_dataset.doc_idx) - 1))

return sft_indexed_dataset
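
# For context, a minimal usage sketch (not part of the patch): the dataclass, field
# names, and path below are hypothetical; the only requirement is that each field maps
# to a matching <field>.bin file under the dataset prefix.
#
# from dataclasses import dataclass
#
# @dataclass
# class SFTExample:      # hypothetical schema: one <field>.bin file per field
#     input_ids: list    # sequence field, read element-wise
#     labels: list       # sequence field
#     seq_len: int       # scalar field, one value per sequence
#
# dataset = make_sft_dataset("/data/sft", SFTExample, skip_warmup=True)
# doc = dataset[0]       # a list of SFTExample instances, one per sequence in document 0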


def dataset_exists(path, impl):
if impl == "mmap":
return MMapIndexedDataset.exists(path)
@@ -120,6 +134,18 @@ def index_file_path(prefix_path):
return prefix_path + ".idx"


def sft_index_file_path(prefix_path):
return os.path.join(prefix_path, "index.idx")


def sft_data_file_path(prefix_path, dataclass):
file_path_list = []
for field in fields(dataclass):
file_path = os.path.join(prefix_path, f"{field.name}.bin")
file_path_list.append(file_path)
return file_path_list
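
# Illustration only, using the hypothetical SFTExample schema sketched above:
# these helpers resolve to one shared index plus one data file per field.
#
# sft_index_file_path("/data/sft")
# -> "/data/sft/index.idx"
# sft_data_file_path("/data/sft", SFTExample)
# -> ["/data/sft/input_ids.bin", "/data/sft/labels.bin", "/data/sft/seq_len.bin"]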


def data_file_path(prefix_path):
return prefix_path + ".bin"

@@ -548,13 +574,259 @@ def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class SFTMMapIndexedDataset(paddle.io.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"

@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", code(dtype)))

return self

@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers

def write(self, sizes, doc_idx):

pointers = self._get_pointers(sizes)
self._file.write(struct.pack("<Q", len(sizes)))
self._file.write(struct.pack("<Q", len(doc_idx)))

sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes

pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers

doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))

def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()

return _Writer()

def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version

(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize

self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()

if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)

self._buffer_mmap = np.memmap(path, mode="r", order="C")
self._buffer = memoryview(self._buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(self._buffer, dtype=np.int32, count=self._len, offset=offset)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)

def __del__(self):
self._buffer_mmap._mmap.close()
del self._buffer_mmap

@property
def dtype(self):
return self._dtype

@property
def sizes(self):
return self._sizes

@property
def doc_idx(self):
return self._doc_idx

@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]

def __len__(self):
return self._doc_count - 1

def __init__(self, path, dataclass, skip_warmup=False):
super().__init__()
self._dataclass = dataclass
self._path = None
self._index = None
self._bin_buffer = None

self._do_init(path, skip_warmup)

def __getstate__(self):
return self._path

def __setstate__(self, state):
self._do_init(state, skip_warmup=True)

def _do_init(self, path, skip_warmup):
self._path = path
if not self.exists(path, self._dataclass):
raise ValueError("Missing file, %s" % (path))

self._index = self.Index(sft_index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
for data_file in sft_data_file_path(self._path, self._dataclass):
_warmup_mmap_file(data_file)
print_rank_0(" creating numpy buffer of mmap...")

self._bin_buffer_mmap_dict = {}
self._bin_buffer_dict = {}
for data_file in sft_data_file_path(self._path, self._dataclass):
self._bin_buffer_mmap_dict[data_file] = np.memmap(data_file, mode="r", order="C")
self._bin_buffer_dict[data_file] = memoryview(self._bin_buffer_mmap_dict[data_file])
print_rank_0(" creating memory view of numpy buffer...")

def __del__(self):
for key, value in self._bin_buffer_mmap_dict.items():
value._mmap.close()
for key, value in self._bin_buffer_dict.items():
del value
del self._index

def __len__(self):
return len(self._index)

def __getitem__(self, idx):
def get_index(idx):
doc_idx = self._index.doc_idx
start_sentence, end_sentence = doc_idx[idx], doc_idx[idx + 1]
start_pointers, _ = self._index[start_sentence]
length_list = self._index._sizes[start_sentence:end_sentence]

dataclass_fields = fields(self._dataclass)
dataclass_list = []
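# sequence_offset walks byte-wise through each sequence field's .bin file, while
# scalar_offset advances one dtype-sized element per sequence for int-typed fields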
sequence_offset = start_pointers
scalar_offset = doc_idx[idx] * np.dtype(self._index.dtype).itemsize

for length in length_list:
field_data = {field.name: [] for field in dataclass_fields}
for field in dataclass_fields:
bin_buffer = self._bin_buffer_dict[os.path.join(self._path, f"{field.name}.bin")]
if field.type != int:
data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=length, offset=sequence_offset)
field_data[field.name] = data.tolist()
else:
data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=1, offset=scalar_offset)
field_data[field.name] = int(data[0])

dataclass_list.append(self._dataclass(**field_data))

sequence_offset += length * np.dtype(self._index.dtype).itemsize
scalar_offset += np.dtype(self._index.dtype).itemsize
return dataclass_list

if isinstance(idx, (int, np.integer)):
return get_index(idx)
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
return [get_index(idx) for idx in range(start, stop)]

@property
def sizes(self):
return self._index.sizes

@property
def doc_idx(self):
return self._index.doc_idx

def get_doc_idx(self):
return self._index._doc_idx

def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_

@property
def supports_prefetch(self):
return False

@staticmethod
def exists(path, dataclass):
file_path_list = sft_data_file_path(path, dataclass)
file_path_list.append(sft_index_file_path(path))
for file_path in file_path_list:
if not os.path.exists(file_path):
return False
return True
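
# As a reference for the binary layout, a small inspection sketch (illustration only;
# peek_sft_index is not part of the patch). It mirrors what Index.writer emits and
# Index.__init__ reads back: the MMIDIDX magic, a version, a dtype code, then the
# sizes/pointers arrays and the document index.
#
# import struct
# import numpy as np
#
# def peek_sft_index(path):
#     with open(path, "rb") as f:
#         assert f.read(9) == b"MMIDIDX\x00\x00"          # _HDR_MAGIC
#         (version,) = struct.unpack("<Q", f.read(8))      # always 1
#         (dtype_code,) = struct.unpack("<B", f.read(1))   # key into the module's dtypes map
#         (num_sequences,) = struct.unpack("<Q", f.read(8))
#         (doc_count,) = struct.unpack("<Q", f.read(8))
#         sizes = np.fromfile(f, dtype=np.int32, count=num_sequences)
#         pointers = np.fromfile(f, dtype=np.int64, count=num_sequences)
#         doc_idx = np.fromfile(f, dtype=np.int64, count=doc_count)
#         return version, dtype_code, sizes, pointers, doc_idx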


def make_builder(out_file, impl, save_dtype, loss_mask_file=None):
if impl == "mmap":
return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file)
else:
return IndexedDatasetBuilder(out_file, dtype=save_dtype)


class SFTMMapIndexedDatasetBuilder(object):
def __init__(self, output_file_dict, dtype):
self._data_file_dict = {}
for key, filename in output_file_dict.items():
self._data_file_dict[key] = open(filename, "wb")
self.output_file_dict = output_file_dict
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]

def add_item(self, sequence):
add_sequence_len = False
for key in self._data_file_dict.keys():
tensor = np.array(getattr(sequence, key), dtype=self._dtype)
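# record the sequence length once per item, taken from the first sequence-valued (size > 1) field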
if tensor.size > 1 and not add_sequence_len:
self._sizes.append(tensor.size)
add_sequence_len = True
self._data_file_dict[key].write(tensor.tobytes(order="C"))

def end_document(self):
self._doc_idx.append(len(self._sizes))

def finalize(self, index_file):
for key, filename in self._data_file_dict.items():
filename.close()
with SFTMMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
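
# A write-path sketch under the same hypothetical layout (paths, field names, and the
# tokenized_examples iterable are illustrative; the dtype must be one of the module's
# supported dtype codes):
#
# import numpy as np
#
# builder = SFTMMapIndexedDatasetBuilder(
#     {
#         "input_ids": "/data/sft/input_ids.bin",
#         "labels": "/data/sft/labels.bin",
#         "seq_len": "/data/sft/seq_len.bin",
#     },
#     dtype=np.int32,
# )
# for example in tokenized_examples:   # any objects exposing input_ids/labels/seq_len attributes
#     builder.add_item(example)        # one write per field; sizes recorded from the first sequence field
#     builder.end_document()           # here every example is treated as its own document
# builder.finalize("/data/sft/index.idx")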


class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype, loss_mask_file=None):
self._data_file = open(out_file, "wb")