Skip to content

Commit

Permalink
Merge pull request #835 from EleutherAI/large_bs_dataloader
Browse files Browse the repository at this point in the history
Fix Large-BS Dataloader Bug
  • Loading branch information
StellaAthena authored Mar 16, 2023
2 parents 1cf5a30 + a7e3ac3 commit 91b72d9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 14 deletions.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 1ab177a
Default = 7d682df

current git hash of repository

Expand Down
19 changes: 12 additions & 7 deletions megatron/data/gpt2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _build_index_mappings(
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(
" > elasped time to build and save doc-idx mapping "
" > elapsed time to build and save doc-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# sample-idx.
Expand All @@ -174,11 +174,16 @@ def _build_index_mappings(

assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)

num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
sample_idx = helpers.build_sample_idx_int32(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
else:
sample_idx = helpers.build_sample_idx_int64(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save sample-idx mapping "
Expand Down Expand Up @@ -260,7 +265,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):

# Total number of samples. For -1 see comments in `_num_epochs`.
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64)

# Index into sample_idx.
sample_index = 0
Expand Down
99 changes: 93 additions & 6 deletions megatron/data/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
}
}

py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch)
py::array build_sample_idx_int32(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch)
{
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
Expand Down Expand Up @@ -173,6 +173,92 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
free_when_done); // numpy array references
}


py::array build_sample_idx_int64(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch)
{
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/

// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);

// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();

// Mapping and it's length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int64_t* sample_idx = new int64_t[2 * (num_samples + 1)];

cout << " using:" << endl << std::flush;
cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl
<< std::flush;
cout << " number of epochs: " << num_epochs << endl << std::flush;
cout << " sequence length: " << seq_length << endl << std::flush;
cout << " total number of samples: " << num_samples << endl << std::flush;

// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;

while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}

// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void* mem_) {
int64_t* mem = reinterpret_cast<int64_t*>(mem_);
delete[] mem;
});

// Return the numpy array.
const auto byte_size = sizeof(int64_t);
return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen)
Expand Down Expand Up @@ -665,6 +751,7 @@ PYBIND11_MODULE(helpers, m)
{
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_sample_idx_int32", &build_sample_idx_int32);
m.def("build_sample_idx_int64", &build_sample_idx_int64);
m.def("build_blending_indices", &build_blending_indices);
}

0 comments on commit 91b72d9

Please sign in to comment.