Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions megatron/data/indexed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,22 +359,16 @@ def _get_pointers(sizes, npdtype):
     """Return a numpy array of byte offsets given a list of sizes.

     Multiplies values in the sizes array by dtype size (bytes),
-    and then computes a zero-based prefix scan.
+    and then computes an exclusive scan to get byte offsets.
     """

-    # create numpy array of desired numpy datatype
-    pointers = np.array(sizes, dtype=npdtype)
+    # compute element sizes in bytes
+    bytesizes = np.array(sizes, dtype=npdtype)
+    bytesizes *= dtype().itemsize

-    if len(sizes) > 0:
-        # scale each element by its dtype size
-        dtype_size = dtype().itemsize
-        pointers *= dtype_size
-
-        # in-place prefix scan to compute byte offsets
-        np.cumsum(pointers, axis=0, out=pointers)
-
-        # zero-base the prefix scan (exclusive scan)
-        pointers -= pointers[0]
+    # exclusive scan to get byte offsets
+    pointers = np.cumsum(bytesizes, axis=0)
+    pointers -= bytesizes

     return pointers

Expand Down