diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 2b2c1f405..acdf36246 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -359,22 +359,16 @@ def _get_pointers(sizes, npdtype): """Return a numpy array of byte offsets given a list of sizes. Multiplies values in the sizes array by dtype size (bytes), - and then computes a zero-based prefix scan. + and then computes an exclusive scan to get byte offsets. """ - # create numpy array of desired numpy datatype - pointers = np.array(sizes, dtype=npdtype) + # compute element sizes in bytes + bytesizes = np.array(sizes, dtype=npdtype) + bytesizes *= dtype().itemsize - if len(sizes) > 0: - # scale each element by its dtype size - dtype_size = dtype().itemsize - pointers *= dtype_size - - # in-place prefix scan to compute byte offsets - np.cumsum(pointers, axis=0, out=pointers) - - # zero-base the prefix scan (exclusive scan) - pointers -= pointers[0] + # exclusive scan to get byte offsets + pointers = np.cumsum(bytesizes, axis=0) + pointers -= bytesizes return pointers