diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 61ebb0de8..6bb2ded07 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 1ab177a + Default = 7d682df current git hash of repository diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 76bf87d52..cd6cf8676 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -164,7 +164,7 @@ def _build_index_mappings( doc_idx = _build_doc_idx(documents, num_epochs, np_rng) np.save(doc_idx_filename, doc_idx, allow_pickle=True) print_rank_0( - " > elasped time to build and save doc-idx mapping " + " > elapsed time to build and save doc-idx mapping " "(seconds): {:4f}".format(time.time() - start_time) ) # sample-idx. @@ -174,11 +174,16 @@ def _build_index_mappings( assert doc_idx.dtype == np.int32 assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch - ) - # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, - # num_epochs, tokens_per_epoch) + + num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length + if 2 * (num_samples + 1) < np.iinfo(np.int32).max: + sample_idx = helpers.build_sample_idx_int32( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) + else: + sample_idx = helpers.build_sample_idx_int64( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) np.save(sample_idx_filename, sample_idx, allow_pickle=True) print_rank_0( " > elapsed time to build and save sample-idx mapping " @@ -260,7 +265,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): # Total number of samples. For -1 see comments in `_num_epochs`. num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length - sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) # Index into sample_idx. sample_index = 0 diff --git a/megatron/data/helpers.cpp b/megatron/data/helpers.cpp index 37a49b523..90488fa61 100644 --- a/megatron/data/helpers.cpp +++ b/megatron/data/helpers.cpp @@ -88,11 +88,11 @@ void build_blending_indices(py::array_t& dataset_index, } } -py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) +py::array build_sample_idx_int32(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { /* Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened and the samples are built based on this @@ -173,6 +173,92 @@ py::array build_sample_idx(const py::array_t& sizes_, free_when_done); // numpy array references } + +py::array build_sample_idx_int64(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) +{ + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int64_t* sample_idx = new int64_t[2 * (num_samples + 1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl << std::flush; + cout << " sequence length: " << seq_length << endl << std::flush; + cout << " total number of samples: " << num_samples << endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Beginning offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the beginning of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void* mem_) { + int64_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int64_t); + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references +} + inline int32_t get_target_sample_len(const int32_t short_seq_ratio, const int32_t max_length, std::mt19937& rand32_gen) @@ -665,6 +751,7 @@ PYBIND11_MODULE(helpers, m) { m.def("build_mapping", &build_mapping); m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); + m.def("build_sample_idx_int32", &build_sample_idx_int32); + m.def("build_sample_idx_int64", &build_sample_idx_int64); m.def("build_blending_indices", &build_blending_indices); }