From 191a96b025074352d30d512bab57b7884db52503 Mon Sep 17 00:00:00 2001
From: Adam Moody
Date: Wed, 18 Aug 2021 02:52:21 -0500
Subject: [PATCH 01/20] fix: exclusive scan computing pointers list (#68)

---
 megatron/data/indexed_dataset.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 2b2c1f405..acdf36246 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -359,22 +359,16 @@ def _get_pointers(sizes, npdtype):
         """Return a numpy array of byte offsets given a list of sizes.

         Multiplies values in the sizes array by dtype size (bytes),
-        and then computes a zero-based prefix scan.
+        and then computes an exclusive scan to get byte offsets.
         """

-        # create numpy array of desired numpy datatype
-        pointers = np.array(sizes, dtype=npdtype)
+        # compute element sizes in bytes
+        bytesizes = np.array(sizes, dtype=npdtype)
+        bytesizes *= dtype().itemsize

-        if len(sizes) > 0:
-            # scale each element by its dtype size
-            dtype_size = dtype().itemsize
-            pointers *= dtype_size
-
-            # in-place prefix scan to compute byte offsets
-            np.cumsum(pointers, axis=0, out=pointers)
-
-            # zero-base the prefix scan (exclusive scan)
-            pointers -= pointers[0]
+        # exclusive scan to get byte offsets
+        pointers = np.cumsum(bytesizes, axis=0)
+        pointers -= bytesizes

         return pointers


From 5eeae0b802223ce49158f2c11f8fbe182c912220 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Wed, 18 Aug 2021 10:35:12 +0200
Subject: [PATCH 02/20] - Recompute bin/idx using microsoft/Megatron-DeepSpeed
 (no changes)
 - Add a test for stability compared to the official repo.

---
 tests/test_preprocessing.py | 35 ++++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 3caf9f552..4d27d94b3 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import filecmp
 import io
 import json
 import re
@@ -77,3 +77,36 @@ def test_preprocess_data(self):
         for ext in ["bin", "idx"]:
             tgt_path = f"{output_prefix}_text_document.{ext}"
             self.assertTrue(Path(tgt_path).exists(), )
+
+    def test_process_data_microsoft(self):
+        """We want to stay stable relative to the Microsoft version."""
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)
+
+        input_path = f"{self.tests_dir}/tools/openwebtext-1000.jsonl"
+
+        output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext"
+
+        cmd = f"""
+        python {src_dir}/tools/preprocess_data.py
+            --input {input_path}
+            --output-prefix {output_prefix}
+            --dataset-impl mmap
+            --tokenizer-type GPT2BPETokenizer
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab {data_dir}/gpt2-tiny-vocab.json
+            --append-eod
+            --workers 2
+        """.split()
+
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        for ext in ["bin", "idx"]:
+            tgt_path = f"{output_prefix}_text_document.{ext}"
+            ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}"
+            self.assertTrue(Path(tgt_path).exists(), )
+            self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False))
+

From 9341269d43cb0ca9b0da3f2ac9d8c60445f53e73 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Wed, 18 Aug 2021 10:36:48 +0200
Subject: [PATCH 03/20] Add openwebtext-1000.jsonl to .gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 18e7deab8..dfaa172e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 # tests
 # megatron autogenerated indices
 tests/data/*/*npy
+tests/tools/openwebtext-1000.jsonl

 # macOS
 .DS_Store

From fb274bfe23cf75714698b670e80f8e957d9015f6 Mon Sep 17 00:00:00 2001
From: Adam Moody
Date: Mon, 9 Aug 2021 23:22:41 -0700
Subject: [PATCH 04/20] abstraction to index and randomly access jsonl files

---
 tools/indexed_json.py           | 289 ++++++++++++++++++++++++++++++++
 tools/preprocess_dataset_mpi.py |  16 +-
 2 files changed, 300 insertions(+), 5 deletions(-)
 create mode 100644 tools/indexed_json.py

diff --git a/tools/indexed_json.py b/tools/indexed_json.py
new file mode 100644
index 000000000..67bd57778
--- /dev/null
+++ b/tools/indexed_json.py
@@ -0,0 +1,289 @@
+import os
+import stat
+import json
+import time
+import numpy as np
+from mpi4py import MPI
+
+class IndexedJSON(object):
+    def __init__(self, filename, comm=None, bufsize=16*1024*1024):
+        self.filename = filename    # JSON file name
+        self.comm = comm            # distributed environment for collective ops
+        self.bufsize = bufsize      # buffer size used while building index
+        self.numsamples = 0         # number of records in JSON file
+        self.fh_idx = None          # file handle to JSON index file
+        self.fh_json = None         # file handle to JSON file
+        self.time_index = 0         # record cost to construct index
+
+        # given a JSON file name, compute the name of its index file
+        filename_idx = self.index_filename(filename)
+
+        # check for index file and create it if it does not exist
+        exists = False
+        if self.comm.Get_rank() == 0:
+            exists = os.path.exists(filename_idx)
+        exists = self.comm.bcast(exists)
+        if not exists:
+            self.create_index(filename, self.bufsize)
+
+        # Identify number of samples in JSON file.
+        # For now, we can do that using the size of index file.
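+        # (Each index entry is an (offset, length) pair of int64 values,
+        # so the index file holds 16 bytes per sample.)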
+ filesize_idx = self.get_filesize(filename_idx) + self.numsamples = int(filesize_idx / 16) + + # Open the index and the json files for reading. + # Disable buffering to avoid reading extra bytes we won't use. + self.fh_idx = open(filename_idx, "rb", buffering=0) + self.fh_json = open(filename, "rb", buffering=0) +# self.fh_idx = open(filename_idx, "rb") +# self.fh_json = open(filename, "rb") + + def create_shared_file(self, filename): + self.comm.barrier() + + rank = self.comm.Get_rank() + if rank == 0: + f = open(filename, "wb") + self.comm.barrier() + if rank != 0: + f = open(filename, "r+b") + + self.comm.barrier() + return f + + def get_filesize(self, filename): + """Lookup filesize and broadcast to all ranks.""" + filesize = 0 + rank = self.comm.Get_rank() + if rank == 0: + filesize = os.stat(filename)[stat.ST_SIZE] + filesize = self.comm.bcast(filesize) + return filesize + + def get_count(self, val): + """Compute global sum of val across all ranks.""" + inval = np.array([val], dtype=np.int64) + outval = np.zeros_like(inval) + self.comm.Allreduce(inval, outval, op=MPI.SUM) + return outval[0] + + def get_offset(self, val): + """Execute exclusive scan prefix sum of val across ranks.""" + inval = np.array([val], dtype=np.int64) + outval = np.zeros_like(inval) + self.comm.Scan(inval, outval, op=MPI.SUM) + return outval[0] - inval[0] + + def index_filename(self, filename): + """Given the name of a JSON file, return the name of its index file.""" + return filename + '.idx' + + def get_start_end(self, num): + """Given num items, compute and return [start,end) range on each rank.""" + rank = self.comm.Get_rank() + num_ranks = self.comm.Get_size() + + num_per_rank = num // num_ranks + remainder = num % num_ranks + if rank < remainder: + start = (num_per_rank + 1) * rank; + end = start + (num_per_rank + 1) + else: + start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder); + end = start + num_per_rank + return start, end + + def create_index(self, filename, bufsize): + """Given a JSON file named dataset.json, write index to dataset.json.idx.""" + + # To compute this index, ranks collective scan the JSON + # file and record the byte offset of newline characters. + # Those byte offsets are accumulated in a temporary index file + # until the entire JSON file has been scanned. The processes + # then read back those byte locations from the temporary file + # to compute the length of each record. Finally for each + # record an (offset,length) pair of int64 types is written into + # the index file to specify the starting offset and length of + # each record in the JSON file. 
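+
+        # A hypothetical walk-through of the scheme above: with two ranks and
+        # bufsize=8 scanning the file b'{"a":1}\n{"b":2}\n', rank 0 finds a
+        # newline at byte 7 and stores offset 8, while rank 1 finds one at
+        # byte 15 and stores offset 16. The temporary file then holds
+        # [8, 16], and the second pass converts that into the (offset,length)
+        # pairs [(0,8), (8,8)] written to the index file.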
+ + time_start = time.time() + rank = self.comm.Get_rank() + numranks = self.comm.Get_size() + + # define file names for the index and the temporary index file + filename_idx = self.index_filename(filename) + filename_tmp = filename_idx + 'tmp' + + # lookup and broadcast size of JSON file to all ranks + filesize = self.get_filesize(filename) + + # create the temporary index file, shared across all ranks + with self.create_shared_file(filename_tmp) as ftmp: + # open and scan the JSON file + recstart = 0 + with open(filename, "rb") as f: + curpos = 0 + while curpos < filesize: + # each rank reads a section of the file + offset = curpos + bufsize * rank + f.seek(offset) + data = f.read(bufsize) + + # scan section for newline chars, and record offset + # of byte immediately following each newline (start of new record) + newlines = [] + pos = 0 + length = len(data) + while pos < length: + found = data.find(b'\n', pos) + if found >= 0: + # We actually store the byte offset to the start + # of the record that would follow the newline char. + newlines.append(offset + found + 1) + + # Update our buffer position and keep scanning. + pos = found + 1 + else: + # No newlines in the remainder of the buffer + break + + # Count number of newline chars we found, + # and compute sum and offset of newlines across ranks. + numrecs = len(newlines) + reccount = self.get_count(numrecs) + recoffset = self.get_offset(numrecs) + + # Store offsets as int64 + vals = np.array(newlines, dtype=np.int64) + + # Write offsets into temporary index file + pos = (recstart + recoffset) * 8 + ftmp.seek(pos) + ftmp.write(vals.tobytes(order='C')) + + # Bump up to next slot in the temporary index file. + recstart += reccount + + # Move on to the next section of the JSON file. + curpos += bufsize * numranks + + # Wait for all ranks to close the file. + self.comm.barrier() + + # Create the actual index file. + with self.create_shared_file(filename_idx) as fidx: + # We'll read the offsets back from the temporary index file. + with open(filename_tmp, "rb") as ftmp: + # Compute the [start,end) range for this rank within the list of offsets. + start, end = self.get_start_end(recstart) + + # Determine how many records this rank is responsible for. + readcount = end - start + if readcount > 0: + # We'll read all offsets in our portion, + # plus one offset that comes immediately before our section. + readcount += 1 + if start > 0: + pos = (start - 1) * 8 + ftmp.seek(pos) + + # Allocate a buffer and read in the offsets + recoffsets = np.zeros(readcount, dtype=np.int64) + if start > 0: + ftmp.readinto(recoffsets) + else: + # We leave the first entry as 0 on rank 0 + ftmp.readinto(recoffsets[1:]) + + # Compute length of each record by computing the difference + # between consecutive offset values. Also ignore the first + # offset when writing our (offset,length) pairs. + lengths = recoffsets[1:] - recoffsets[:-1] + offsets = recoffsets[:-1] + + # Prepare list of (offset,length) pairs for writing. + # Store are int64 types. + vals = np.zeros((readcount - 1, 2), dtype=np.int64) + vals[:,0] = offsets + vals[:,1] = lengths + + # Write our portion of the index values. + if readcount > 0: + pos = start * 16 + fidx.seek(pos) + fidx.write(vals.tobytes(order='C')) + + # Wait for all ranks to finish writing to the index file. + self.comm.barrier() + + # Can now delete the temporary index file. + if rank == 0: + os.remove(filename_tmp) + + # Wait for everyone again and record how long it took. 
+ self.comm.barrier() + time_end = time.time() + self.time_index = time_end - time_start + + def __str__(self): + return (f"IndexedJSON (\n" + f" file: {self.filename}\n" + f" rows: {self.numsamples}\n" + f" secs: {self.time_index})") + + def __len__(self): + """Return number of samples (lines) in the JSON file.""" + return self.numsamples + + def __getitem__(self, idx): + """Given a sample id, return the sample as a python object parsed from JSON string.""" + # read offset and length of record from the index + # seek to offset in JSON file and read the record + buf = self.read(idx) + + # convert json record into a python dictionary + try: + #entry = json.loads(buf.decode("utf-8").strip()) + entry = json.loads(buf) + return entry + except: + # TODO: throw exception instead? + return None + + def __get__(self, idx): + return self.getitem(idx) + + def index(self, idx): + """Given an sample id, return (offset, size) tuple of location of sample in the JSON file.""" + assert idx < self.numsamples + + # seek to the right spot in the index file for the given sample id + offset_idx = idx * 16 + self.fh_idx.seek(offset_idx) + + # read offset and length of record from the index + vals = np.zeros(2, dtype=np.int64) + self.fh_idx.readinto(vals) + offset = vals[0] + size = vals[1] + + return offset, size + + def pread(self, offset, size): + """Read size bytes at the given offset in the JSON file and return as a buffer.""" + # seek to offset in JSON file and read the record + self.fh_json.seek(offset) + buf = self.fh_json.read(size) + return buf + + def read(self, idx): + """Given a sample id, read sample from the file and return as a buffer.""" + # read offset and length of record from the index + # seek to offset in JSON file and read the record + offset, size = self.index(idx) + return self.pread(offset, size) + + def size(self, idx): + """Given a sample id, return the number of bytes of that sample as stored in the JSON file.""" + offset, size = self.index(idx) + return size diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 331764a29..a3037a669 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -64,6 +64,8 @@ from datasets import config, logging, load_dataset from datasets.utils.file_utils import OfflineModeIsEnabled +from indexed_json import IndexedJSON + from megatron.tokenizer import build_tokenizer from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist from megatron.data.distdata import DistData @@ -606,11 +608,15 @@ def main(): startup_start = time.time() # load the dataset - dset, err = load_dset(args) - if dset is None: - if err is not None: - raise err - return + if args.input.endswith(".json"): + # assume file is JSONL format + dset = IndexedJSON(args.input, args.mpi_comm) + else: + dset, err = load_dset(args) + if dset is None: + if err is not None: + raise err + return if args.rank == 0: print(dset) msg(f"Processing features: {args.columns}") From d428c025552387b003fc33a1bb12b954ae51e9d5 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 18:03:22 -0700 Subject: [PATCH 05/20] rebase on parallel merge, replace mpi4py with distdata class --- tools/indexed_json.py | 63 +++++++++------------------------ tools/preprocess_dataset_mpi.py | 2 +- 2 files changed, 18 insertions(+), 47 deletions(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 67bd57778..17fb21211 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ 
-3,12 +3,11 @@ import json import time import numpy as np -from mpi4py import MPI class IndexedJSON(object): - def __init__(self, filename, comm=None, bufsize=16*1024*1024): + def __init__(self, filename, distctx, bufsize=16*1024*1024): self.filename = filename # JSON file name - self.comm = comm # distributed environment for collective ops + self.distctx = distctx # distributed environment for collective ops self.bufsize = bufsize # buffer size used while building index self.numsamples = 0 # number of records in JSON file self.fh_idx = None # file handle to JSON index file @@ -20,9 +19,9 @@ def __init__(self, filename, comm=None, bufsize=16*1024*1024): # check for index file and create it if it does not exist exists = False - if self.comm.Get_rank() == 0: + if self.distctx.rank == 0: exists = os.path.exists(filename_idx) - exists = self.comm.bcast(exists) + exists = self.distctx.bcast(exists, root=0) if not exists: self.create_index(filename, self.bufsize) @@ -38,50 +37,22 @@ def __init__(self, filename, comm=None, bufsize=16*1024*1024): # self.fh_idx = open(filename_idx, "rb") # self.fh_json = open(filename, "rb") - def create_shared_file(self, filename): - self.comm.barrier() - - rank = self.comm.Get_rank() - if rank == 0: - f = open(filename, "wb") - self.comm.barrier() - if rank != 0: - f = open(filename, "r+b") - - self.comm.barrier() - return f - def get_filesize(self, filename): """Lookup filesize and broadcast to all ranks.""" filesize = 0 - rank = self.comm.Get_rank() - if rank == 0: + if self.distctx.rank == 0: filesize = os.stat(filename)[stat.ST_SIZE] - filesize = self.comm.bcast(filesize) + filesize = self.distctx.bcast(filesize, root=0) return filesize - def get_count(self, val): - """Compute global sum of val across all ranks.""" - inval = np.array([val], dtype=np.int64) - outval = np.zeros_like(inval) - self.comm.Allreduce(inval, outval, op=MPI.SUM) - return outval[0] - - def get_offset(self, val): - """Execute exclusive scan prefix sum of val across ranks.""" - inval = np.array([val], dtype=np.int64) - outval = np.zeros_like(inval) - self.comm.Scan(inval, outval, op=MPI.SUM) - return outval[0] - inval[0] - def index_filename(self, filename): """Given the name of a JSON file, return the name of its index file.""" return filename + '.idx' def get_start_end(self, num): """Given num items, compute and return [start,end) range on each rank.""" - rank = self.comm.Get_rank() - num_ranks = self.comm.Get_size() + rank = self.distctx.rank + num_ranks = self.distctx.numranks num_per_rank = num // num_ranks remainder = num % num_ranks @@ -107,8 +78,8 @@ def create_index(self, filename, bufsize): # each record in the JSON file. time_start = time.time() - rank = self.comm.Get_rank() - numranks = self.comm.Get_size() + rank = self.distctx.rank + numranks = self.distctx.numranks # define file names for the index and the temporary index file filename_idx = self.index_filename(filename) @@ -118,7 +89,7 @@ def create_index(self, filename, bufsize): filesize = self.get_filesize(filename) # create the temporary index file, shared across all ranks - with self.create_shared_file(filename_tmp) as ftmp: + with self.distctx.open(filename_tmp) as ftmp: # open and scan the JSON file recstart = 0 with open(filename, "rb") as f: @@ -150,8 +121,8 @@ def create_index(self, filename, bufsize): # Count number of newline chars we found, # and compute sum and offset of newlines across ranks. 
numrecs = len(newlines) - reccount = self.get_count(numrecs) - recoffset = self.get_offset(numrecs) + reccount = self.distctx.sum(numrecs) + recoffset = self.distctx.exscan(numrecs) # Store offsets as int64 vals = np.array(newlines, dtype=np.int64) @@ -168,10 +139,10 @@ def create_index(self, filename, bufsize): curpos += bufsize * numranks # Wait for all ranks to close the file. - self.comm.barrier() + self.distctx.barrier() # Create the actual index file. - with self.create_shared_file(filename_idx) as fidx: + with self.distctx.open(filename_idx) as fidx: # We'll read the offsets back from the temporary index file. with open(filename_tmp, "rb") as ftmp: # Compute the [start,end) range for this rank within the list of offsets. @@ -214,14 +185,14 @@ def create_index(self, filename, bufsize): fidx.write(vals.tobytes(order='C')) # Wait for all ranks to finish writing to the index file. - self.comm.barrier() + self.distctx.barrier() # Can now delete the temporary index file. if rank == 0: os.remove(filename_tmp) # Wait for everyone again and record how long it took. - self.comm.barrier() + self.distctx.barrier() time_end = time.time() self.time_index = time_end - time_start diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index a3037a669..686f9be33 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -610,7 +610,7 @@ def main(): # load the dataset if args.input.endswith(".json"): # assume file is JSONL format - dset = IndexedJSON(args.input, args.mpi_comm) + dset = IndexedJSON(args.input, args.distctx) else: dset, err = load_dset(args) if dset is None: From bd6f41fb1d407d0545d05ef9627c2bab48f5e237 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 19 Aug 2021 15:46:33 -0700 Subject: [PATCH 06/20] look for extension .jsonl --- tools/indexed_json.py | 2 +- tools/preprocess_data_dist.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 17fb21211..430647783 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ -65,7 +65,7 @@ def get_start_end(self, num): return start, end def create_index(self, filename, bufsize): - """Given a JSON file named dataset.json, write index to dataset.json.idx.""" + """Given a JSON file named dataset.jsonl, write index to dataset.jsonl.idx.""" # To compute this index, ranks collective scan the JSON # file and record the byte offset of newline characters. 
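For orientation, a record lookup through the finished index costs one 16-byte read of the .idx file plus one seek-and-read of the .jsonl file. A minimal single-process sketch of that access pattern, assuming the headerless index layout the series uses at this point (plain native int64 (offset, length) pairs; read_sample is a hypothetical helper, not part of these patches):

    import numpy as np

    def read_sample(json_path, idx):
        # each index entry is two int64 values: (byte offset, record length)
        with open(json_path + '.idx', 'rb') as fh_idx:
            fh_idx.seek(idx * 16)
            offset, size = np.frombuffer(fh_idx.read(16), dtype=np.int64)
        # seek to the record in the JSONL file and read exactly one record
        with open(json_path, 'rb') as fh_json:
            fh_json.seek(offset)
            return fh_json.read(size)

Note that the native-endian dtype here only matches the on-disk layout until patch 16 switches the index to network byte order, and patch 17 later adds a 16-byte header in front of the pairs.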
diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index ebf3b1efe..4db5bdab9 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -612,7 +612,7 @@ def main(): startup_start = time.time() # load the dataset - if args.input.endswith(".json"): + if args.input.endswith(".jsonl"): # assume file is JSONL format dset = IndexedJSON(args.input, args.distctx) else: From a96e2ab31a6fa1994e6c3df6eb8875faf9b7c6aa Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 21 Aug 2021 08:39:54 -0700 Subject: [PATCH 07/20] [testing] fixes for pt-1.10 (#71) * fixes for pt-1.10 * switch to torch_assert_equal wrapper --- megatron/testing_utils.py | 4 ++++ tests/test_activations.py | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index d402b59df..62991c044 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -207,6 +207,10 @@ def get_gpu_count(): else: return 0 +def torch_assert_equal(actual, expected): + """ emulates the removed torch.testing.assert_equal """ + torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0) + def get_tests_dir(append_path=None): """ diff --git a/tests/test_activations.py b/tests/test_activations.py index 85c949f4a..98b91d376 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -5,7 +5,7 @@ from torch.nn import functional as F from megatron.model.activations import liglu, geglu, reglu, swiglu -from megatron.testing_utils import set_seed +from megatron.testing_utils import set_seed, torch_assert_equal class TestActivations(unittest.TestCase): @@ -27,16 +27,16 @@ def test_shapes(self): def test_liglu(self): expected = self.x1 * self.x2 - torch.testing.assert_equal(liglu(self.x), expected) + torch_assert_equal(liglu(self.x), expected) def test_geglu(self): expected = self.x1 * F.gelu(self.x2) - torch.testing.assert_equal(geglu(self.x), expected) + torch_assert_equal(geglu(self.x), expected) def test_reglu(self): expected = self.x1 * F.relu(self.x2) - torch.testing.assert_equal(reglu(self.x), expected) + torch_assert_equal(reglu(self.x), expected) def test_swiglu(self): expected = self.x1 * F.silu(self.x2) - torch.testing.assert_equal(swiglu(self.x), expected) + torch_assert_equal(swiglu(self.x), expected) From b5a029d9b2470b66ec56b7dbe1230799ba07e0f1 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 04:23:13 +0900 Subject: [PATCH 08/20] Expose GLU activations as arguments (#69) * feat: expose glu activations as argument * chore: rename activations -> glu_activations * refactor: use lookup dict instead of `getattr()` * refactor: mv lookup dict to `glu_activations.py` * chore: rm unnecessary default arg * test: add bf16 test; gelu in `test_training_all()` * Update megatron/testing_utils.py Co-authored-by: Stas Bekman * refactor: use `require_torch_bf16` decorator * chore: comment out bf16 test uncomment in the future when torch supports gelu kernels for bf16 * consistent style * fix look up table * better grouping * fix: replace hard coded options with `GLU_ACTIVATIONS` Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman --- megatron/arguments.py | 5 ++++ .../{activations.py => glu_activations.py} | 12 +++++++-- megatron/model/transformer.py | 5 +++- megatron/testing_utils.py | 26 +++++++++++++++++++ tests/test_activations.py | 18 +++++++++---- tests/test_training.py | 1 + 6 files changed, 59 insertions(+), 8 deletions(-) rename megatron/model/{activations.py => glu_activations.py} (84%) diff 
--git a/megatron/arguments.py b/megatron/arguments.py index 326c948ee..ba1d0c9a1 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -22,6 +22,7 @@ import deepspeed from megatron.enums import PositionEmbeddingType +from megatron.model.glu_activations import GLU_ACTIVATIONS def parse_args(extra_args_provider=None, defaults={}, @@ -313,6 +314,10 @@ def _add_network_size_args(parser): default=PositionEmbeddingType.absolute, help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) + group.add_argument('--glu-activation', type=str, + choices=GLU_ACTIVATIONS.keys(), + help='GLU activations to use.' + ) return parser diff --git a/megatron/model/activations.py b/megatron/model/glu_activations.py similarity index 84% rename from megatron/model/activations.py rename to megatron/model/glu_activations.py index 82ccdf098..9e0eb5b29 100644 --- a/megatron/model/activations.py +++ b/megatron/model/glu_activations.py @@ -7,10 +7,10 @@ class _GLUBaseModule(nn.Module): def __init__(self, activation_fn): super().__init__() self.activation_fn = activation_fn - + def forward(self, x): # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim-1)) + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) return x1 * self.activation_fn(x2) @@ -38,3 +38,11 @@ def __init__(self): geglu = torch.jit.script(GEGLU()) reglu = torch.jit.script(ReGLU()) swiglu = torch.jit.script(SwiGLU()) + + +GLU_ACTIVATIONS = { + "geglu": geglu, + "liglu": liglu, + "reglu": reglu, + "swiglu": swiglu, +} diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e7612b76f..473b8e06b 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,6 +30,7 @@ import deepspeed +from .glu_activations import GLU_ACTIVATIONS from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels @@ -76,7 +77,9 @@ def __init__(self, init_method, output_layer_init_method): self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu - if args.openai_gelu: + if args.glu_activation: + self.activation_func = GLU_ACTIVATIONS[args.glu_activation] + elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index 62991c044..801d74d9b 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -25,6 +25,7 @@ import random from distutils.util import strtobool from io import StringIO +from packaging import version from pathlib import Path from typing import Iterator, Union from unittest import mock @@ -212,6 +213,31 @@ def torch_assert_equal(actual, expected): torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0) +def is_torch_bf16_available(): + # from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301 + if is_torch_available(): + import torch + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + if int(torch.version.cuda.split(".")[0]) < 11: + return False + if not version.parse(torch.__version__) >= version.parse("1.09"): + return False + return True + else: + return False + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9.""" + if not is_torch_bf16_available(): + 
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case) + else: + return test_case + + def get_tests_dir(append_path=None): """ Args: diff --git a/tests/test_activations.py b/tests/test_activations.py index 98b91d376..a1763d7b4 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -4,7 +4,7 @@ import torch from torch.nn import functional as F -from megatron.model.activations import liglu, geglu, reglu, swiglu +from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu from megatron.testing_utils import set_seed, torch_assert_equal @@ -17,13 +17,13 @@ def setUp(self): self.num_channels = random.randint(1, 384) * 2 self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels) self.x1, self.x2 = self.x.chunk(2, dim=-1) + # glu should halve the last dimension + self.output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] def test_shapes(self): - # glu should halve the last dimension - output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] - for activation_fn in [liglu, geglu, reglu, swiglu]: + for activation_fn in GLU_ACTIVATIONS.values(): output = activation_fn(self.x) - self.assertEqual(list(output.shape), output_shape) + self.assertEqual(list(output.shape), self.output_shape) def test_liglu(self): expected = self.x1 * self.x2 @@ -40,3 +40,11 @@ def test_reglu(self): def test_swiglu(self): expected = self.x1 * F.silu(self.x2) torch_assert_equal(swiglu(self.x), expected) + + # from megatron.testing_utils import require_torch_bf16 + # @require_torch_bf16 + # def test_bf16_jit(self): + # x_bf16 = self.x.to(torch.bfloat16) + # for activation_fn in GLU_ACTIVATIONS.values(): + # output = activation_fn(x_bf16) + # self.assertEqual(list(output.shape), self.output_shape) diff --git a/tests/test_training.py b/tests/test_training.py index 7306615f1..f0e45beaa 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -90,6 +90,7 @@ def test_training_all(self): --eval-interval 10 --eval-iters 5 --checkpoint-activations + --glu-activation geglu --exit-interval {exit_interval} --merge-file {data_dir}/gpt2-tiny-merges.txt From 4255845d0685dd59cb70c0026db45f92b7ae4848 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 22 Aug 2021 12:41:32 -0700 Subject: [PATCH 09/20] fix circular import (#72) --- megatron/arguments.py | 4 ++-- tests/test_activations.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index ba1d0c9a1..5f4e2b53f 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -22,7 +22,7 @@ import deepspeed from megatron.enums import PositionEmbeddingType -from megatron.model.glu_activations import GLU_ACTIVATIONS +import megatron def parse_args(extra_args_provider=None, defaults={}, @@ -315,7 +315,7 @@ def _add_network_size_args(parser): help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) group.add_argument('--glu-activation', type=str, - choices=GLU_ACTIVATIONS.keys(), + choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(), help='GLU activations to use.' 
)

diff --git a/tests/test_activations.py b/tests/test_activations.py
index a1763d7b4..34097ad22 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -43,7 +43,7 @@ def test_swiglu(self):

     # from megatron.testing_utils import require_torch_bf16
     # @require_torch_bf16
-    # def test_bf16_jit(self): 
+    # def test_bf16_jit(self):
     #     x_bf16 = self.x.to(torch.bfloat16)
     #     for activation_fn in GLU_ACTIVATIONS.values():
     #         output = activation_fn(x_bf16)

From 350fb903935f900f400f140d157165d5fa4d7645 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 25 Aug 2021 12:21:40 -0700
Subject: [PATCH 10/20] [codecarbon] integration (#15)

* add codecarbon

* switch to offline

* rework to also restart the tracker at each checkpoint save to ensure as
little data as possible is lost

* adjust API to match https://github.com/bigscience-workshop/codecarbon/pull/1

* fix logging

* new implementation based on https://github.com/mlco2/codecarbon/pull/236

* add test

* update requirements

---
 megatron/arguments.py     |  2 ++
 megatron/checkpointing.py | 10 +++++--
 megatron/global_vars.py   | 59 ++++++++++++++++++++++++++++++++++++++-
 megatron/training.py      |  6 ++++
 requirements.txt          |  2 ++
 tests/test_training.py    |  1 +
 6 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 5f4e2b53f..de16fd5b3 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -457,6 +457,8 @@ def _add_training_args(parser):
                        help='Run optimizer on CPU')
     group.add_argument('--cpu_torch_adam', action='store_true',
                        help='Use Torch Adam as optimizer on CPU.')
+    group.add_argument('--codecarbon-dir', type=str, default=None,
+                       help='Write CodeCarbon logs to this directory.')

     return parser

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 829fb1101..f7328dcbb 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -22,6 +22,7 @@

 import torch

+from megatron.global_vars import codecarbon_tracker_flush
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
@@ -135,7 +136,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
             state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
-    
+
     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
@@ -183,6 +184,11 @@
     if torch.distributed.is_initialized():
         torch.distributed.barrier()

+    # since the code can be exited or aborted in various places we use the checkpoint saving as
+    # a safe saving point for the codecarbon tracker. If the program doesn't run to its normal
+    # end, then only the data since the last saved checkpoint will be lost.
+ codecarbon_tracker_flush() + def _transpose_first_dim(t, num_splits, num_splits_first, model): input_shape = t.size() # We use a self_attention module but the values extracted aren't @@ -417,7 +423,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True def load_biencoder_checkpoint(model, only_query_model=False, only_context_model=False, custom_load_path=None): """ - selectively load retrieval models for indexing/retrieving + selectively load retrieval models for indexing/retrieving from saved checkpoints """ diff --git a/megatron/global_vars.py b/megatron/global_vars.py index 5f3c28463..b5dcac4d9 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -19,6 +19,8 @@ import sys import time +from pathlib import Path + import torch from megatron.tokenizer import build_tokenizer @@ -29,10 +31,10 @@ _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None _GLOBAL_TOKENIZER = None _GLOBAL_TENSORBOARD_WRITER = None +_GLOBAL_CODECARBON_TRACKER = None _GLOBAL_ADLR_AUTORESUME = None _GLOBAL_TIMERS = None - def get_args(): """Return arguments.""" _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') @@ -63,6 +65,10 @@ def get_tensorboard_writer(): to check if it is initialized.""" return _GLOBAL_TENSORBOARD_WRITER +def get_codecarbon_tracker(): + """Return codecarbon tracker. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_CODECARBON_TRACKER def get_adlr_autoresume(): """ADLR autoresume object. It can be None so no need @@ -86,6 +92,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={}, if args.vocab_file or args.tokenizer_name_or_path: _ = _build_tokenizer(args) _set_tensorboard_writer(args) + _set_codecarbon_tracker(args) _set_adlr_autoresume(args) _set_timers() @@ -145,6 +152,56 @@ def _set_tensorboard_writer(args): 'no TensorBoard logs will be written.', flush=True) +def _set_codecarbon_tracker(args): + global _GLOBAL_CODECARBON_TRACKER + if not hasattr(args, 'codecarbon_dir'): + return + + import codecarbon + if args.rank == 0: + print('> setting codecarbon ...') + + output_dir = args.codecarbon_dir + output_file = f"emissions-{args.rank:03d}.csv" + log_level = "warning" + country_iso_code="FRA" + + Path(output_dir).mkdir(parents=True, exist_ok=True) + _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker( + output_dir=output_dir, + output_file=output_file, + log_level=log_level, + country_iso_code=country_iso_code, + ) + + +def codecarbon_tracker_start(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC START") + _GLOBAL_CODECARBON_TRACKER.start() + + +def codecarbon_tracker_stop(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC STOP") + _GLOBAL_CODECARBON_TRACKER.stop() + + +def codecarbon_tracker_flush(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC FLUSH") + _GLOBAL_CODECARBON_TRACKER.flush() + + def _set_adlr_autoresume(args): """Initialize ADLR autoresume.""" global _GLOBAL_ADLR_AUTORESUME diff --git a/megatron/training.py b/megatron/training.py index 21ef13b94..f66544dff 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -51,6 +51,7 @@ from megatron.schedules import forward_backward_pipelining_without_interleaving from megatron.schedules import forward_backward_pipelining_with_interleaving from megatron.utils import report_memory, flops_calculator +from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop 
import deepspeed @@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider, initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults) + codecarbon_tracker_start() + # Adjust the startup time so it reflects the largest value. # This will be closer to what scheduler will see (outside of # image ... launches. @@ -162,6 +165,9 @@ def pretrain(train_valid_test_dataset_provider, test_data_iterator, model, 0, True) + codecarbon_tracker_stop() + + def update_train_iters(args): # For iteration-based training, we don't need to do anything diff --git a/requirements.txt b/requirements.txt index 234a9902a..a96ff42be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ regex numpy transformers # git+https://github.com/microsoft/DeepSpeed.git@big-science +# edit to a higher SHA or future release if needed +git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8 diff --git a/tests/test_training.py b/tests/test_training.py index f0e45beaa..85d1a537b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -98,6 +98,7 @@ def test_training_all(self): --save {output_dir}/checkpoints --load {output_dir}/checkpoints --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --codecarbon-dir {output_dir}/codecarbon --tensorboard-dir {output_dir}/tensorboard --tensorboard-queue-size 5 --log-timers-to-tensorboard From 3488d0bc063dda2b6e12404e62825d07f33dcf4a Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 25 Aug 2021 13:58:37 -0700 Subject: [PATCH 11/20] add progress messages --- tools/indexed_json.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 430647783..328400e5e 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ -5,7 +5,7 @@ import numpy as np class IndexedJSON(object): - def __init__(self, filename, distctx, bufsize=16*1024*1024): + def __init__(self, filename, distctx, bufsize=16*1024*1024, progress=10.0): self.filename = filename # JSON file name self.distctx = distctx # distributed environment for collective ops self.bufsize = bufsize # buffer size used while building index @@ -13,6 +13,7 @@ def __init__(self, filename, distctx, bufsize=16*1024*1024): self.fh_idx = None # file handle to JSON index file self.fh_json = None # file handle to JSON file self.time_index = 0 # record cost to construct index + self.progress = progress # number of secs between progress msgs (0 to disable) # given a JSON file name, compute the name of its index file filename_idx = self.index_filename(filename) @@ -88,9 +89,14 @@ def create_index(self, filename, bufsize): # lookup and broadcast size of JSON file to all ranks filesize = self.get_filesize(filename) + # if progress messages are enabled, print a header about what we're doing + if rank == 0 and self.progress > 0.0: + print(f"Indexing '{filename}' of {filesize} bytes ...", flush=True) + # create the temporary index file, shared across all ranks with self.distctx.open(filename_tmp) as ftmp: # open and scan the JSON file + time_next = time_start + self.progress recstart = 0 with open(filename, "rb") as f: curpos = 0 @@ -138,6 +144,18 @@ def create_index(self, filename, bufsize): # Move on to the next section of the JSON file. 
curpos += bufsize * numranks + # this can take a while, so print progress messages if enabled + if rank == 0 and self.progress > 0.0: + time_now = time.time() + if time_now > time_next: + time_next = time_now + self.progress + elapsed = time_now - time_start + percent = curpos * 100.0 / filesize if filesize > 0 else 0.0 + byterate = curpos / elapsed / (1024.0 * 1024.0) if elapsed > 0.0 else 0.0 + remaining = (100.0 - percent) * elapsed / percent if percent > 0.0 else 0.0 + print(f"Scanned {curpos} of {filesize} bytes ({percent:0.2f}%) in {int(elapsed)} secs, " + f"{byterate:0.3f} MB/s, {int(remaining)} secs left ...", flush=True) + # Wait for all ranks to close the file. self.distctx.barrier() @@ -196,6 +214,10 @@ def create_index(self, filename, bufsize): time_end = time.time() self.time_index = time_end - time_start + # if progress messages are enabled, print a summary + if rank == 0 and self.progress > 0.0: + print(f"Indexed '{filename}' in {int(self.time_index)} seconds", flush=True) + def __str__(self): return (f"IndexedJSON (\n" f" file: {self.filename}\n" From 1305fe93241d6e7b9d28e6ab4ca49fcd93ace14d Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 25 Aug 2021 17:06:35 -0700 Subject: [PATCH 12/20] rebuild index if mtime is old --- tools/indexed_json.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 328400e5e..9cc598169 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ -18,12 +18,20 @@ def __init__(self, filename, distctx, bufsize=16*1024*1024, progress=10.0): # given a JSON file name, compute the name of its index file filename_idx = self.index_filename(filename) - # check for index file and create it if it does not exist - exists = False - if self.distctx.rank == 0: - exists = os.path.exists(filename_idx) - exists = self.distctx.bcast(exists, root=0) + # determine whether we need to create the index + create_index = False + exists = self.test_exists(filename_idx) if not exists: + # index file does not exist + create_index = True + else: + # index exists, but rebuild the index if the original file + # has been modified since the index was built + mtime = self.get_mtime(filename) + mtime_idx = self.get_mtime(filename_idx) + if mtime > mtime_idx: + create_index = True + if create_index: self.create_index(filename, self.bufsize) # Identify number of samples in JSON file. 
@@ -38,6 +46,14 @@ def __init__(self, filename, distctx, bufsize=16*1024*1024, progress=10.0): # self.fh_idx = open(filename_idx, "rb") # self.fh_json = open(filename, "rb") + def test_exists(self, filename): + """Test whether file exists and broadcast result to all ranks.""" + exists = False + if self.distctx.rank == 0: + exists = os.path.exists(filename) + exists = self.distctx.bcast(exists, root=0) + return exists + def get_filesize(self, filename): """Lookup filesize and broadcast to all ranks.""" filesize = 0 @@ -46,6 +62,14 @@ def get_filesize(self, filename): filesize = self.distctx.bcast(filesize, root=0) return filesize + def get_mtime(self, filename): + """Lookup file mtime and broadcast to all ranks.""" + mtime = 0 + if self.distctx.rank == 0: + mtime = os.stat(filename)[stat.ST_MTIME] + mtime = self.distctx.bcast(mtime, root=0) + return mtime + def index_filename(self, filename): """Given the name of a JSON file, return the name of its index file.""" return filename + '.idx' @@ -222,7 +246,7 @@ def __str__(self): return (f"IndexedJSON (\n" f" file: {self.filename}\n" f" rows: {self.numsamples}\n" - f" secs: {self.time_index})") + f")") def __len__(self): """Return number of samples (lines) in the JSON file.""" From e96df7dc54686c68a87699318463f52d7daca018 Mon Sep 17 00:00:00 2001 From: Thomas Wang <24695242+thomasw21@users.noreply.github.com> Date: Thu, 26 Aug 2021 17:03:10 +0200 Subject: [PATCH 13/20] Check cardon directory is not None (#74) --- megatron/global_vars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/global_vars.py b/megatron/global_vars.py index b5dcac4d9..1d94e8cf7 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -154,7 +154,7 @@ def _set_tensorboard_writer(args): def _set_codecarbon_tracker(args): global _GLOBAL_CODECARBON_TRACKER - if not hasattr(args, 'codecarbon_dir'): + if not hasattr(args, 'codecarbon_dir') or args.codecarbon_dir is None: return import codecarbon From 3fd48db52937148989a77dc8dccacb255777cfee Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 26 Aug 2021 10:25:25 -0700 Subject: [PATCH 14/20] [CI] start workflow (#75) * start workflow * fix --- .github/workflows/main.yml | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 .github/workflows/main.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..f2f75fe95 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,66 @@ +name: do-the-job +on: workflow_dispatch +jobs: + start-runner: + name: Start self-hosted EC2 runner + runs-on: ubuntu-latest + outputs: + label: ${{ steps.start-ec2-runner.outputs.label }} + ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Start EC2 runner + id: start-ec2-runner + uses: machulav/ec2-github-runner@v2 + with: + mode: start + github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} + ec2-image-id: ami-0c7563df049759e8b + ec2-instance-type: p3.8xlarge + subnet-id: subnet-3502b45e + security-group-id: sg-e8f46d9d + aws-resource-tags: > # optional, requires additional permissions + [ + {"Key": "Name", "Value": "ec2-github-runner"}, + {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} + ] + do-the-job: + name: Do the 
job on the runner + needs: start-runner # required to start the main job when the runner is ready + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Hello World + run: echo 'Hello World!' + - name: Test + run: pwd; ls -l + - name: GPU Test + run: nvidia-smi + + stop-runner: + name: Stop self-hosted EC2 runner + needs: + - start-runner # required to get output from the start-runner job + - do-the-job # required to wait when the main job is done + runs-on: ubuntu-latest + if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Stop EC2 runner + uses: machulav/ec2-github-runner@v2 + with: + mode: stop + github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} + label: ${{ needs.start-runner.outputs.label }} + ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} From 6d88ae20400e657ea5f7cdf6d9af8f45461e9971 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 26 Aug 2021 10:32:07 -0700 Subject: [PATCH 15/20] [CI] wip (#76) * start workflow * fix * fix * Update .github/workflows/main.yml Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> * Update .github/workflows/main.yml Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f2f75fe95..7d8fa3054 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,7 @@ jobs: with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ secrets.AWS_REGION }} + aws-region: us-east-2 - name: Start EC2 runner id: start-ec2-runner uses: machulav/ec2-github-runner@v2 @@ -56,7 +56,7 @@ jobs: with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ secrets.AWS_REGION }} + aws-region: us-east-2 - name: Stop EC2 runner uses: machulav/ec2-github-runner@v2 with: From 6bcac1fd81576224a7dceff04d5bd8d303bd4874 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 26 Aug 2021 11:14:10 -0700 Subject: [PATCH 16/20] store index values in network byte order --- tools/indexed_json.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 9cc598169..34d6911fc 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ -221,10 +221,12 @@ def create_index(self, filename, bufsize): vals[:,1] = lengths # Write our portion of the index values. + # We write values to the index file in network byte order, + # so that the file can be read correctly on any system. if readcount > 0: pos = start * 16 fidx.seek(pos) - fidx.write(vals.tobytes(order='C')) + fidx.write(vals.astype(">i8").tobytes(order='C')) # Wait for all ranks to finish writing to the index file. 
self.distctx.barrier() @@ -274,12 +276,13 @@ def index(self, idx): """Given an sample id, return (offset, size) tuple of location of sample in the JSON file.""" assert idx < self.numsamples - # seek to the right spot in the index file for the given sample id + # Seek to the right spot in the index file for the given sample id. offset_idx = idx * 16 self.fh_idx.seek(offset_idx) - # read offset and length of record from the index - vals = np.zeros(2, dtype=np.int64) + # Read offset and length of record from the index. + # Values in the index file are stored in network byte order. + vals = np.zeros(2, dtype=">i8") self.fh_idx.readinto(vals) offset = vals[0] size = vals[1] From 813d068349541edf35d2049bcbd9e51c5ca2f16c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 26 Aug 2021 12:18:17 -0700 Subject: [PATCH 17/20] add magic value and format version number to index file --- tools/indexed_json.py | 95 +++++++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 25 deletions(-) diff --git a/tools/indexed_json.py b/tools/indexed_json.py index 34d6911fc..9ebf12aed 100644 --- a/tools/indexed_json.py +++ b/tools/indexed_json.py @@ -1,6 +1,7 @@ import os import stat import json +import struct import time import numpy as np @@ -16,35 +17,42 @@ def __init__(self, filename, distctx, bufsize=16*1024*1024, progress=10.0): self.progress = progress # number of secs between progress msgs (0 to disable) # given a JSON file name, compute the name of its index file - filename_idx = self.index_filename(filename) + self.filename_idx = self.index_filename(self.filename) # determine whether we need to create the index create_index = False - exists = self.test_exists(filename_idx) + exists = self.test_exists(self.filename_idx) if not exists: # index file does not exist create_index = True else: # index exists, but rebuild the index if the original file # has been modified since the index was built - mtime = self.get_mtime(filename) - mtime_idx = self.get_mtime(filename_idx) + mtime = self.get_mtime(self.filename) + mtime_idx = self.get_mtime(self.filename_idx) if mtime > mtime_idx: create_index = True if create_index: - self.create_index(filename, self.bufsize) - - # Identify number of samples in JSON file. - # For now, we can do that using the size of index file. - filesize_idx = self.get_filesize(filename_idx) - self.numsamples = int(filesize_idx / 16) + self.create_index(self.filename, self.bufsize) # Open the index and the json files for reading. # Disable buffering to avoid reading extra bytes we won't use. - self.fh_idx = open(filename_idx, "rb", buffering=0) - self.fh_json = open(filename, "rb", buffering=0) -# self.fh_idx = open(filename_idx, "rb") -# self.fh_json = open(filename, "rb") + self.fh_idx = open(self.filename_idx, "rb", buffering=0) + self.fh_json = open(self.filename, "rb", buffering=0) +# self.fh_idx = open(self.filename_idx, "rb") +# self.fh_json = open(self.filename, "rb") + + self.read_index_header() + + # Identify number of samples in JSON file. + # For now, we can do that using the size of index file. 
+ self.numsamples = None + if self.idx_version == 1: + # version 1 has a 16-byte header + # followed by a list of (offset, length) pairs of uint64 + header_size = 16 + filesize_idx = self.get_filesize(self.filename_idx) + self.numsamples = int((filesize_idx - header_size) / 16) def test_exists(self, filename): """Test whether file exists and broadcast result to all ranks.""" @@ -70,6 +78,32 @@ def get_mtime(self, filename): mtime = self.distctx.bcast(mtime, root=0) return mtime + def read_index_header(self): + """Read header from index file and check its version.""" + # Rank 0 reads the header, and bcasts its version + version = None + if self.distctx.rank == 0: + try: + # Seek to the front of the file + self.fh_idx.seek(0) + + # Read the magic valud and check that it matches what we expect + magic = self.fh_idx.read(8) + if magic == b'INDXJSON': + # Good magic value, now read file format version number + buf = self.fh_idx.read(8) + if len(buf) == 8: + version = struct.unpack(">Q", buf)[0] + except Exception as e: + pass + + # Get version from rank 0 (should be None on any error) + self.idx_version = self.distctx.bcast(version, root=0) + + # Check that we have a version number that we support + if self.idx_version != 1: + raise ValueError("Unknown index file format version '{self.idx_version}' in file '{self.filename_idx}'") + def index_filename(self, filename): """Given the name of a JSON file, return the name of its index file.""" return filename + '.idx' @@ -185,6 +219,12 @@ def create_index(self, filename, bufsize): # Create the actual index file. with self.distctx.open(filename_idx) as fidx: + # Rank 0 writes the index file header. + if self.distctx.rank == 0: + fidx.write(b'INDXJSON') # use 8-byte magic value of "INDXJSON" + fidx.write(struct.pack(">Q", 1)) # file format version number in network byte order + data_offset = 16 + # We'll read the offsets back from the temporary index file. with open(filename_tmp, "rb") as ftmp: # Compute the [start,end) range for this rank within the list of offsets. @@ -224,7 +264,7 @@ def create_index(self, filename, bufsize): # We write values to the index file in network byte order, # so that the file can be read correctly on any system. if readcount > 0: - pos = start * 16 + pos = data_offset + start * 16 fidx.seek(pos) fidx.write(vals.astype(">i8").tobytes(order='C')) @@ -276,18 +316,23 @@ def index(self, idx): """Given an sample id, return (offset, size) tuple of location of sample in the JSON file.""" assert idx < self.numsamples - # Seek to the right spot in the index file for the given sample id. - offset_idx = idx * 16 - self.fh_idx.seek(offset_idx) + if self.idx_version == 1: + # Version 1 has a 16-byte header followed by a + # list of (offset, length) pairs of uint64 + + # Seek to the right spot in the index file for the given sample id. + header_size = 16 + offset_idx = header_size + idx * 16 + self.fh_idx.seek(offset_idx) - # Read offset and length of record from the index. - # Values in the index file are stored in network byte order. - vals = np.zeros(2, dtype=">i8") - self.fh_idx.readinto(vals) - offset = vals[0] - size = vals[1] + # Read offset and length of record from the index. + # Values in the index file are stored in network byte order. 
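+            # (For example, an offset of 300 is stored on disk as the
+            # big-endian bytes 00 00 00 00 00 00 01 2c, and the ">i8"
+            # dtype below makes numpy byte-swap as needed when reading
+            # on little-endian hosts.)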
+ vals = np.zeros(2, dtype=">i8") + self.fh_idx.readinto(vals) + offset = vals[0] + size = vals[1] - return offset, size + return offset, size def pread(self, offset, size): """Read size bytes at the given offset in the JSON file and return as a buffer.""" From 972211163608818fe9e5ba821246f18d0a5dc264 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 26 Aug 2021 14:22:58 -0500 Subject: [PATCH 18/20] distributed merge of per-rank Megatron data files (#55) * add parallel merge using mpi * handle case where some ranks might have 0 items * add inclusive scan prefix sum * report more timing info * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * rename total size variable for clarity * move translation to bin/idx file names a level deeper * parallel merge for cached dataset * add alltrue function * move collectives to new distdata class, add torch.distributed * drop unused prefix_sum function * allow ranks to pass a list of files to be merged * check that input dataset files exist * fix: using wrong doc_idx list for mmap * move init dist and collectives to distdata class * add --merge option, move parallel/serial to their own functions * Update megatron/data/distdata.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> * drop extraneous numpy tolist calls * rename self.MPI to mpi4py * handle case where no ranks have elements in their file * rename tokenize_start to time_start * drop unrelated comment in distdata.min * add comment why pointers_shift is not None and add assert * note why pointers uses sizes count and offset values * can just rely on rank 0 for the leading 0 element * add write_list function * determine element size * add checks for consistent element_size values * check that at least one rank has a file to merge * assert that torch backend is gloo or mpi * add collectives for assert and raise * rename to allassert and allraise_if * check dtype instead of element_size * add uint32 to element_sizes table * infer dtype from files being merged * add write_header function to indexed dataset classes * call write_header internally from IndexedDataset classes * return number of bytes written from write calls * move scatterv to distdata class * add functions to format status and error messages * defer merge_files_dist to future PR * open files using with, refresh comments * rely on default torch datatypes * fix some status messages from preprocess script * fix: exclusive scan computing pointers list * fix: exclusive scan to compute mmap pointers list * note about seek * rename preprocess_dataset_mpi.py to preprocess_data_dist.py * update usage comments at top of script * restore commented 
print_rank_0 statements
* restore status message in mmap merge_file_
* drop mpi4py, sad :(
* add test case for parallel merge
* add preprocess_data_dist test for serial merge
* improve error handling
* refactor get_pointers code
* bug fix in exscan
* further refactor get_pointers
* move exscan collective for pointers outside of try block
* clarify some comments
* include string 1k in name of test files
* use temporary file for index
* fix: implement scatterv from torch.distributed.scatter
* switch to pad method in torch.nn.functional
* return data received in scatterv as new tensor
* raise exception if conflicting scratch and merge options
* use allraise method from distdata in preprocess_data_dist

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
---
 megatron/data/distdata.py | 220 ++++++++
 megatron/data/indexed_dataset.py | 517 +++++++++++++++++-
 tests/test_preprocessing.py | 66 ++-
 ...dataset_mpi.py => preprocess_data_dist.py} | 397 +++++++-------
 4 files changed, 992 insertions(+), 208 deletions(-)
 create mode 100644 megatron/data/distdata.py
 rename tools/{preprocess_dataset_mpi.py => preprocess_data_dist.py} (61%)

diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py
new file mode 100644
index 000000000..659d51bb9
--- /dev/null
+++ b/megatron/data/distdata.py
@@ -0,0 +1,220 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+class DistDataError(Exception):
+ """Defines an empty exception to throw when some other rank hits a real exception."""
+ pass
+
+class DistData(object):
+ def __init__(self, backend='gloo'):
+ assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'"
+
+ dist.init_process_group(backend, init_method="env://")
+
+ # lookup our process rank and the group size
+ self.rank = dist.get_rank()
+ self.numranks = dist.get_world_size()
+
+ def allassert(self, cond, msg):
+ """Check that cond is True on all ranks, assert with msg everywhere if not.
+
+ To prevent deadlocks in cases where an assertion might only fail on one rank,
+ this executes an allreduce to ensure that if any rank finds that an assertion
+ has been violated, all ranks fail an assertion check.
+ The condition must be true on all ranks for this not to assert.
+ """
+ alltrue = self.alltrue(cond)
+ assert alltrue, msg
+
+ def allraise_if(self, err):
+ """Raise exception if err is not None on any rank.
+
+ Similarly to allassert, this raises an exception on all ranks if err
+ is set to an exception on any rank. Rank(s) where err is not None
+ re-raise err as exception, and ranks where err is None raise DistDataError.
+ Thus all ranks raise an exception if any rank has an active exception,
+ which helps avoid deadlocks in cases where an exception may be raised
+ on a subset of ranks.
+ """
+ alltrue = self.alltrue(err is None)
+ if not alltrue:
+ # At least one rank raised an exception.
+ # Re-raise the actual exception if this rank threw one.
+ if err is not None:
+ raise err
+
+ # TODO: is there a better exception to use here?
+ # On other ranks, raise an "empty" exception to indicate
+ # that we're only failing because someone else did.
+ raise DistDataError + + def barrier(self): + """Globally synchronize all processes""" + dist.barrier() + + def bcast(self, val, root): + """Broadcast a scalar value from root to all ranks""" + vals = [val] + dist.broadcast_object_list(vals, src=root) + return vals[0] + + def scatterv_(self, invals: np.array, counts: list, root:int=0): + """Scatter int64 values from invals according to counts array, return received portion in a new tensor""" + + self.allassert(len(counts) == self.numranks, + f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}") + + # Define list of tensors to scatter on the root. + # torch.distributed.scatter requires each tensor to be the same shape, + # so find the max size across all count values and pad. + max_size = max(counts) + scatterlist = None + if self.rank == root: + slices = list(torch.split(torch.from_numpy(invals), counts)) + scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices] + + # Receive a tensor of the max count size from the root, + # then copy values into output numpy array, which may be smaller. + recvtensor = torch.zeros(max_size, dtype=torch.int64) + dist.scatter(recvtensor, scatterlist, src=root) + return recvtensor[:counts[self.rank]] + + def alltrue(self, val): + """Returns True if all procs input True, False otherwise""" + # torch.dist does not support reductions with bool types + # so we cast to int and cast the result back to bool + tensor = torch.tensor([int(val)], dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.BAND) + return bool(tensor[0]) + + def sum(self, val): + """Compute sum of a scalar val, and return total on all ranks.""" + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor[0] + + def exscan(self, val: int): + """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank.""" + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.zeros(self.numranks, dtype=torch.int64) + tensor[self.rank:] = val + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return int(tensor[self.rank]) - val + + def min(self, val): + """Return minimum of scalar val to all ranks.""" + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + return tensor[0] + + def minrank(self, cond): + """Find first rank whose condition is True, return that rank if any, None otherwise.""" + minrank = self.numranks + if cond: + minrank = self.rank + minrank = self.min(minrank) + + if minrank < self.numranks: + return minrank + return None + + def bcast_first(self, val): + """Broadcast val from first rank where it is not None, return val if any, None otherwise""" + # Find the first rank with a valid value. + minrank = self.minrank(val is not None) + + # If there is no rank with a valid value, return None + if minrank is None: + return None + + # Otherwise broadcast the value from the first valid rank. + val = self.bcast(val, root=minrank) + return val + + def all_sum_(self, vals: np.array): + """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" + # Builds torch.tensor with from_numpy to use same underlying memory as numpy array. 
+ tensor = torch.from_numpy(vals)
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+ def open(self, filename, truncate=None):
+ """Create, truncate, and open a file shared by all ranks."""
+
+ # Don't truncate existing file until all ranks reach this point
+ self.barrier()
+
+ # We'll capture any exception in this variable
+ err = None
+
+ # Rank 0 creates and truncates file.
+ if self.rank == 0:
+ try:
+ f = open(filename, 'wb')
+
+ # Some file systems like GPFS deliver faster write speed
+ # if the file size is known before data is written to the file.
+ if truncate is not None:
+ f.truncate(truncate)
+
+ except Exception as e:
+ err = e
+
+ # Verify that rank 0 created the file
+ self.allraise_if(err)
+
+ # Wait for rank 0 to open (and truncate) file,
+ # then have all ranks open file for writing.
+ if self.rank != 0:
+ try:
+ f = open(filename, 'r+b')
+ except Exception as e:
+ err = e
+
+ # Verify that all ranks successfully opened the file
+ self.allraise_if(err)
+
+ return f
+
+ def remove(self, filename):
+ """Remove a shared file."""
+
+ # Don't remove the file until all are ready
+ self.barrier()
+
+ # We'll capture any exception in this variable
+ err = None
+
+ # Rank 0 removes the file if it exists.
+ if self.rank == 0:
+ try:
+ if os.path.exists(filename):
+ os.remove(filename)
+ except Exception as e:
+ err = e
+
+ # Verify that rank 0 successfully removed the file.
+ self.allraise_if(err)
+
+ def rename(self, srcfile, destfile):
+ """Rename a shared file."""
+
+ # Don't rename until all are ready
+ self.barrier()
+
+ # We'll capture any exception in this variable
+ err = None
+
+ # Rank 0 renames the file.
+ if self.rank == 0:
+ try:
+ if os.path.exists(srcfile):
+ os.rename(srcfile, destfile)
+ except Exception as e:
+ err = e
+
+ # Verify that the rename succeeded
+ self.allraise_if(err)

diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index acdf36246..025e9e333 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -12,6 +12,7 @@
 from functools import lru_cache
 import os
+import stat
 import shutil
 import struct
 from itertools import accumulate
@@ -268,6 +269,7 @@ class IndexedDatasetBuilder(object):
 element_sizes = {
 np.uint8: 1,
 np.int8: 1,
+ np.uint16: 2,
 np.int16: 2,
 np.int32: 4,
 np.int64: 8,
@@ -275,6 +277,22 @@ class IndexedDatasetBuilder(object):
 np.double: 8
 }

+ @staticmethod
+ def write_header(fout, dtype, numdata, numsize, numdoc):
+ """Writes header for cached indexed dataset to given file handle, return number of bytes written."""
+ startpos = fout.tell()
+
+ fout.write(IndexedDataset._HDR_MAGIC)
+ fout.write(struct.pack('<Q', 1))
+ fout.write(struct.pack('<QQ', code(dtype), dtype().itemsize))
+ fout.write(struct.pack('<QQ', numdata, numsize))
+ fout.write(struct.pack('<Q', numdoc))
+
+ endpos = fout.tell()
+ return endpos - startpos
+
+def exscan_from_cumsum_(arr):
+ # given an array holding the result of an inclusive scan (cumsum),
+ # convert it in place to an exclusive scan, e.g.,
+ # [10, 30, 35, 50] --> [0, 10, 30, 35]
+ if arr.size > 1:
+ arr[1:] = arr[:-1]
+ if arr.size > 0:
+ arr[0] = 0
+
+
+def get_pointers_with_total(sizes, elemsize, dtype):
+ """Return a numpy array of type np.dtype giving the byte offsets.
+
+ Multiplies values in the sizes array by elemsize (bytes),
+ and then computes an exclusive scan to get byte offsets.
+ Returns the total number of bytes as second item in a tuple.
+ """ + + # scale values in sizes array by elemsize to get sizes in bytes + pointers = np.array(sizes, dtype=dtype) + pointers *= elemsize + np.cumsum(pointers, axis=0, out=pointers) + + # get total number of bytes from all sizes (last element) + bytes_last = pointers[-1] if len(sizes) > 0 else 0 + + # convert to byte offsets + exscan_from_cumsum_(pointers) + + return pointers, bytes_last + + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' + @staticmethod + def write_header(fout, dtype, numsizes, numdocs): + """Writes header for mmap indexed dataset to given file handle, return number of bytes written.""" + startpos = fout.tell() + + fout.write(MMapIndexedDataset.Index._HDR_MAGIC) + fout.write(struct.pack(' 0, "All ranks have no input files to merge" + + # Check that files are all of the same index type + indexstr = gather_files_dist_check_impltype(filelist, distctx) + + # Concatenate the data files + gather_files_dist_bin(filemain, filelist, distctx) + + # Combine index files into a single index file + if indexstr == "cached": + gather_files_dist_idx_cached(filemain, filelist, distctx) + elif indexstr == "mmap": + gather_files_dist_idx_mmap(filemain, filelist, distctx) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 4d27d94b3..3723f97bc 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -78,6 +78,13 @@ def test_preprocess_data(self): tgt_path = f"{output_prefix}_text_document.{ext}" self.assertTrue(Path(tgt_path).exists(), ) + def compare_meg_data_files(self, tgt, ref): + for ext in ["bin", "idx"]: + tgt_path = f"{tgt}.{ext}" + ref_path = f"{ref}.{ext}" + self.assertTrue(Path(tgt_path).exists(), ) + self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + def test_process_data_microsoft(self): """We want to be stable to Microsoft version.""" src_dir = self.src_dir @@ -104,9 +111,58 @@ def test_process_data_microsoft(self): # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die execute_subprocess_async(cmd, env=self.get_env()) - for ext in ["bin", "idx"]: - tgt_path = f"{output_prefix}_text_document.{ext}" - ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}" - self.assertTrue(Path(tgt_path).exists(), ) - self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + + def test_process_data_dist_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + + def test_process_data_dist_serial_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = 
f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --merge serial + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_data_dist.py similarity index 61% rename from tools/preprocess_dataset_mpi.py rename to tools/preprocess_data_dist.py index 3cd366160..a2c9c0970 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_data_dist.py @@ -20,26 +20,28 @@ from datasets import load_dataset dset = load_dataset('openwebtext', split='train') -The implementation can use `mpi4py` or `torch.distributed` for node communication, and it assumes that -files are written to a global file system, such that one process +The implementation uses `torch.distributed` for inter-process communication, +and it assumes that files are written to a global file system, such that one process can read a file written by another process. A list of sample index values from the source dataset are selected -by rank 0 and broadcast to all ranks. +by rank 0 and scattered to all ranks. Each process tokenizes a subset of samples and writes its output to a part file. -After all ranks have finished, rank 0 merges and deletes the part files. +After all ranks have finished, the part files are merged into a final output file. + +One may optionally use storage local to each process to store the part file. +For example, on a Linux cluster, one might write the part file to /dev/shm. 
To run: -mpiexec -np 320 python preprocess_dataset_mpi.py \ - --input openwebtext \ - --shuffle \ - --seed 100 \ - --output-prefix openwebtext-bert \ - --vocab bert-large-uncased-vocab.txt \ - --dataset-impl mmap \ - --tokenizer-type BertWordPieceLowerCase \ - --split-sentences +python -m torch.distributed.launch --nproc_per_node 40 --nnodes 8 \ + preprocess_data_dist.py \ + --input openwebtext \ + --output-prefix openwebtext-bert \ + --vocab bert-large-uncased-vocab.txt \ + --dataset-impl mmap \ + --tokenizer-type BertWordPieceLowerCase \ + --split-sentences """ import argparse @@ -54,7 +56,6 @@ import random import torch -import torch.distributed as dist try: import nltk nltk_available = True @@ -65,7 +66,15 @@ from datasets.utils.file_utils import OfflineModeIsEnabled from megatron.tokenizer import build_tokenizer -from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype +from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist +from megatron.data.distdata import DistData + +def msg(msg, flush=False): + timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") + print(f"{timestamp}: {msg}", flush=flush) + +def msgerr(msg, flush=False): + print(f"ERROR: {msg}", flush=flush) # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -92,7 +101,7 @@ def __init__(self, args): if self.args.split_sentences: if not nltk_available: - print("NLTK is not available to split sentences.") + msgerr("NLTK is not available to split sentences.") exit() splitter = nltk.load("tokenizers/punkt/english.pickle") if self.args.keep_newlines: @@ -160,41 +169,52 @@ def get_args(): choices=['lazy', 'cached', 'mmap']) group = parser.add_argument_group(title='runtime') - group.add_argument('--mpi4py', action='store_true', - help='Assume script has been launched as an MPI job, and use MPI for communication.') - group.add_argument('--torch-backend', type=str, default='gloo', choices = ['gloo', 'mpi'], + group.add_argument('--torch-backend', type=str, default='gloo', choices=['gloo', 'mpi'], help='Select torch.distributed backend.') group.add_argument('--local_rank', type=int, default=None, help='Local rank of calling process on its node (from torch.distributed.launch).') + group.add_argument('--merge', type=str, default='parallel', choices=['parallel', 'serial', 'both'], + help=('Method to merge intermediate per-rank files into the final data files. ' + 'With "parallel", each rank writes directly to the final files, ' + 'while rank 0 copies data from all per-rank files with "serial". ' + 'A parallel merge can be faster, but for correctness, it requires the underlying file system ' + 'to support parallel write operations to a file that is shared among multiple processes. ' + 'One can choose "both" for testing purposes, in which case the final files written ' + 'by the parallel method are given an additional ".par" extension.')) + group.add_argument('--scratch', type=str, default=None, + help=('Path to local storage on compute nodes to write per-rank files before merging, like /dev/shm. 
' + 'One can only use this option with a parallel merge.')) group.add_argument('--log-interval', type=int, default=30, help='Seconds between progress updates (0 to disable)') args = parser.parse_args() args.keep_empty = False - if args.tokenizer_type.lower().startswith('bert'): - if not args.split_sentences: - print("Bert tokenizer detected, are you sure you don't want to split sentences?") + # initialize our distributed environment + args.distctx = DistData(backend=args.torch_backend) - args.level = "document" - if args.split_sentences: - args.level = "sentence" + # some functions like build_tokenizer use args.rank to filter stdout messages + args.rank = args.distctx.rank + args.numranks = args.distctx.numranks # some default/dummy values for the tokenizer - args.rank = 0 args.make_vocab_size_divisible_by = 128 args.tensor_model_parallel_size = 1 args.vocab_extra_ids = 0 - # use mpi4py instead of torch.distributed if requested - args.use_mpi = False - if args.mpi4py: - try: - from mpi4py import MPI - args.MPI = MPI - args.use_mpi = True - except: - print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) + if args.tokenizer_type.lower().startswith('bert'): + if not args.split_sentences: + if args.rank == 0: + msg("Bert tokenizer detected, are you sure you don't want to split sentences?") + + args.level = "document" + if args.split_sentences: + args.level = "sentence" + + # TODO: perhaps more user friendly to disable scratch and print a warning? + # check that serial merge is not attempted with scratch + if args.scratch is not None and args.merge != 'parallel': + raise ValueError("The --scratch option is only valid with --merge=parallel") return args @@ -202,65 +222,6 @@ def format_byterate(byterate): mbps = byterate / (1024.0 * 1024.0) return f"{mbps:0.3f} MB/s" -def init_distributed(args): - """Determine which distributed runtime to use and connect up processes""" - # select our distributed runtime (MPI or torch.distributed) - # lookup our process rank and the group size - # some functions like build_tokenizer use args.rank to filter stdout messages - if args.use_mpi: - args.mpi_comm = args.MPI.COMM_WORLD - args.rank = args.mpi_comm.Get_rank() - args.numranks = args.mpi_comm.Get_size() - else: - dist.init_process_group(args.torch_backend, init_method="env://") - args.rank = dist.get_rank() - args.numranks = dist.get_world_size() - -def barrier(args): - """Globally synchronize all processes.""" - if args.use_mpi: - args.mpi_comm.barrier() - else: - dist.barrier() - -def scatterv_(args, invals, counts, outval, root=0): - """Scatter int64 values from invals according to counts array, receive values in outval""" - assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}" - assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}" - - if args.use_mpi: - counts = np.array(counts) - displs = np.cumsum(counts) - counts - args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root) - else: - scatterlist = None - if args.rank == root: - scatterlist = list(torch.split(torch.from_numpy(invals), counts)) - outtensor = torch.from_numpy(outval) - dist.scatter(outtensor, scatterlist, src=root) - -def all_sum_(args, vals): - """Sums values in vals element-wise and updates vals with final result on all ranks""" - if args.use_mpi: - outval = np.zeros_like(vals) - 
args.mpi_comm.Allreduce(vals, outval, op=args.MPI.SUM) - vals[:] = outval - else: - tensor = torch.from_numpy(vals) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - -def all_true(args, val): - """Returns True if all procs input True, False otherwise""" - if args.use_mpi: - inval = np.array([val], dtype=np.bool_) - outval = np.zeros_like(inval) - args.mpi_comm.Allreduce(inval, outval, op=args.MPI.LAND) - return bool(outval[0]) - else: - tensor = torch.tensor([int(val)], dtype=torch.int32) - dist.all_reduce(tensor, op=dist.ReduceOp.BAND) - return bool(tensor[0]) - def load_dset(args): # Avoid downloading datasets unless explicitly requested. # We allow the user to override this behavior if they set $HF_DATASETS_OFFLINE. @@ -280,32 +241,29 @@ def load_dset(args): if args.rank != 0: logging.set_verbosity(logging.ERROR) + time_start = time.time() + # Load the specified HuggingFace dataset. # Give rank 0 a head start in case the dataset is not already cached. - success = True err = None dsetname = args.input if args.rank == 0: - print(f"Opening dataset {dsetname}") + msg(f"Opening dataset {dsetname}") try: dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) except OfflineModeIsEnabled as e: - print(f"ERROR: Cannot download '{dsetname}' since running in offline mode.") - print(f"ERROR: If the dataset is large, it may be more efficient to download with a single process:") - print(f"ERROR: from datasets import load_dataset") - print(f"ERROR: dset = load_dataset('{dsetname}')") - print(f"ERROR: Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True) - success = False + msgerr(f"Cannot download '{dsetname}' since running in offline mode.") + msgerr(f"If the dataset is large, it may be more efficient to download with a single process:") + msgerr(f" from datasets import load_dataset") + msgerr(f" dset = load_dataset('{dsetname}')") + msgerr(f"Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True) err = e except Exception as e: - print("ERROR: Unexpected error:", sys.exc_info()[0], flush=True) - success = False + msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True) err = e # determine whether rank 0 succeeded in loading the dataset - success = all_true(args, success) - if not success: - return None, err + args.distctx.allraise_if(err) # Rank 0 succeeded, attempt to load dataset on all other ranks. # This should load from cache now. 
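The rank-0-first load in this hunk follows a pattern that is easy to lift out. Below is a minimal sketch built only on the DistData API added in this series; `rank0_first` and `load` are hypothetical names, with `load` standing in for any callable that may raise, such as `datasets.load_dataset`:

    from megatron.data.distdata import DistData

    def rank0_first(distctx: DistData, load):
        """Run load() on rank 0 first, then on the remaining ranks.

        Errors are re-raised collectively through allraise_if, so no rank
        can hang in a later collective waiting on a rank that failed.
        """
        err = None
        result = None
        if distctx.rank == 0:
            try:
                result = load()  # rank 0 pays the download / cache-miss cost
            except Exception as e:
                err = e
        distctx.allraise_if(err)  # everyone raises together if rank 0 failed

        if distctx.rank != 0:
            try:
                result = load()  # expected to hit the cache rank 0 just warmed
            except Exception as e:
                err = e
        distctx.allraise_if(err)  # everyone raises together if any rank failed
        return result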
@@ -314,18 +272,17 @@ def load_dset(args):
 if args.rank != 0:
 try:
 dset = load_dataset(dsetname, split=args.split, keep_in_memory=None)
 except Exception as e:
 # this print might be noisy, but better than nothing
- print("ERROR: Unexpected error:", sys.exc_info()[0], flush=True)
- success = False
+ msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True)
 err = e

 # verify that all ranks loaded the dataset
- success = all_true(args, success)
- if not success:
- if args.rank == 0:
- print(f"ERROR: At least one process failed to load {dsetname}", flush=True)
- return None, err
+ args.distctx.allraise_if(err)
+
+ time_end = time.time()
+ if args.rank == 0:
+ msg(f"Seconds to load dataset: {time_end - time_start}", flush=True)

- return dset, err
+ return dset

 def get_num_samples(args, dset_size):
 """Given a dataset size and optional count argument, return number of samples to process."""
@@ -342,6 +299,7 @@ def select_sample_list(args, dset_size):
 # create sample index list on rank 0,
 # optionally shuffle the list,
 # and optionally limit the sample count
+ time_select = time.time()
 idxlist = None
 if args.rank == 0:
 # generate a list of all index values
@@ -360,12 +318,19 @@ def select_sample_list(args, dset_size):
 # get a list of the number of elements each rank will hold
 counts = get_proc_counts(num_samples, args.numranks)

- # allocate space to hold its portion of the list
- idx = np.zeros(counts[args.rank], np.int64)
-
 # scatter sample index values from rank 0 to all procs
 # based on distribution defined in counts list
- scatterv_(args, idxlist, counts, idx, root=0)
+ time_bcast = time.time()
+ idx = args.distctx.scatterv_(idxlist, counts, root=0)
+
+ args.distctx.barrier()
+ time_end = time.time()
+ if args.rank == 0:
+ msg("Select index stats:")
+ msg(f" Shuffle: {args.shuffle}")
+ msg(f" Seconds to select: {time_bcast - time_select}")
+ msg(f" Seconds to scatter: {time_end - time_bcast}")
+ msg(f" Seconds total: {time_end - time_select}", flush=True)

 return idx
@@ -376,6 +341,11 @@ def get_proc_counts(num, num_ranks):
 def get_filename(args, key, rank=None):
 pathname = args.output_prefix

+ # redirect per-rank file to scratch dir if defined
+ if args.scratch is not None and rank is not None:
+ basename = os.path.basename(pathname)
+ pathname = os.path.join(args.scratch, basename)
+
 if rank is not None:
 filename = f"{pathname}_{key}_{args.level}_{rank}"
 else:
@@ -384,7 +354,7 @@ def get_filename(args, key, rank=None):
 return filename

 def rank_files_write(args, dset, idx, encoder):
- tokenize_start = time.time()
+ time_start = time.time()

 # compute total number of samples we're processing
 num_samples = get_num_samples(args, len(dset))
@@ -394,13 +364,13 @@ def rank_files_write(args, dset, idx, encoder):
 dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes

 # we'll set this to false on any problem
- success = True
 err = None
+ times = np.zeros(3, dtype=np.float32) # read, tokenize, write
 try:
 # create data file for each rank
 if args.rank == 0:
- print(f"Vocab size: {args.vocab_size}")
- print(f"Output prefix: {args.output_prefix}")
+ msg(f"Vocab size: {args.vocab_size}")
+ msg(f"Output prefix: {args.output_prefix}")
 output_bin_files = {}
 output_idx_files = {}
 builders = {}
 for key in args.columns:
 filebase = get_filename(args, key, args.rank)
 output_bin_files[key] = data_file_path(filebase)
 output_idx_files[key] = index_file_path(filebase)
+ best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
 builders[key] = 
make_builder(output_bin_files[key], impl=args.dataset_impl, - dtype=best_fitting_dtype(args.vocab_size)) + dtype=best_dtype) # each rank tokenizes its samples and writes its own file progress_next = time.time() + float(args.log_interval) @@ -418,10 +389,13 @@ def rank_files_write(args, dset, idx, encoder): sample_id = int(i) for key in args.columns: # tokenize text for the given sample index + start_read = time.time() text = dset[sample_id][key] + start_encode = time.time() doc, bytes_processed = encoder.encode_text(text) # add tokenized sequence to our data file + start_write = time.time() for key, sentences in doc.items(): for sentence in sentences: builders[key].add_item(torch.IntTensor(sentence)) @@ -429,22 +403,26 @@ def rank_files_write(args, dset, idx, encoder): dset_stats[0] += 1 dset_stats[1] += len(sentences) dset_stats[2] += bytes_processed + end_write = time.time() + + times[0] += start_encode - start_read + times[1] += start_write - start_encode + times[2] += end_write - start_write if args.rank == 0 and args.log_interval > 0 and time.time() > progress_next: current = time.time() progress_next = current + float(args.log_interval) - elapsed = current - tokenize_start - timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") + elapsed = current - time_start docs = dset_stats[0] * args.numranks percent = docs / num_samples * 100.0 docrate = docs / elapsed if elapsed > 0.0 else 0.0 mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0 secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0) - print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),", - f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,", - f"{secs_left} secs left ...", - flush=True) + msg(f"Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%) in {int(elapsed)} secs, " + f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s, " + f"{secs_left} secs left ...", + flush=True) # finalize file of each rank for key in args.columns: @@ -452,40 +430,80 @@ def rank_files_write(args, dset, idx, encoder): del builders[key] # file closed in __del__ except Exception as e: # caught an exception, assume our file is invalid - success = False err = e # In case rank 0 finishes early and stops printing progress messages, # inform user that it's waiting for other ranks to finish. 
 if args.rank == 0 and args.log_interval > 0:
- timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
- print(f"{timestamp}: Waiting for ranks to finalize files ...", flush=True)
+ msg("Waiting for ranks to finalize files ...", flush=True)

 # wait for all ranks to finish their files
- barrier(args)
- tokenize_end = time.time()
+ args.distctx.barrier()
+ time_end = time.time()

 # compute total stats across all processes
- all_sum_(args, dset_stats)
+ args.distctx.all_sum_(times)
+ args.distctx.all_sum_(dset_stats)
 if args.rank == 0:
- secs = tokenize_end - tokenize_start
+ secs = time_end - time_start
 docrate = dset_stats[0] / secs if secs > 0.0 else 0.0
 sentrate = dset_stats[1] / secs if secs > 0.0 else 0.0
 byterate = dset_stats[2] / secs if secs > 0.0 else 0.0
- print("Tokenize stats:", secs)
- print(f" Seconds to tokenize: {secs}")
- print(f" {dset_stats[0]} docs {docrate} docs/sec")
- print(f" {dset_stats[1]} sents {sentrate} sents/sec")
- print(f" {dset_stats[2]} bytes {format_byterate(byterate)}")
-
- # allreduce to check whether all ranks wrote their part successfully
- success = all_true(args, success)
- return success, err
-
-def rank_files_merge(args):
- # rank 0 merges all per-rank files
+ secs_read_per_sample = times[0] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+ secs_encode_per_sample = times[1] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+ secs_write_per_sample = times[2] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+ msg("Process stats:")
+ msg(f" Seconds to process: {secs}")
+ msg(f" {dset_stats[0]} docs {docrate} docs/sec")
+ msg(f" {dset_stats[1]} sents {sentrate} sents/sec")
+ msg(f" {dset_stats[2]} bytes {format_byterate(byterate)}")
+ msg(f" Total read seconds {times[0]}, {secs_read_per_sample} sec/sample")
+ msg(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample")
+ msg(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample")
+
+ # check whether all ranks wrote their part successfully
+ args.distctx.allraise_if(err)
+
+def rank_files_merge_parallel(args):
+ """Each process directly writes its portion of the data from its per-rank file into the final file."""
+ merge_start = time.time()
+ numbytes = np.zeros(1, dtype=np.int64)
+ for key in args.columns:
+ # merge the per-rank file from each process into a single shared file
+ filemain = get_filename(args, key)
+ filerank = get_filename(args, key, args.rank)
+ gather_files_dist(filemain, [filerank], args.distctx)
+
+ # total up bytes read during the merge
+ binfilerank = data_file_path(filerank)
+ idxfilerank = index_file_path(filerank)
+ numbytes[0] += os.stat(binfilerank)[stat.ST_SIZE]
+ numbytes[0] += os.stat(idxfilerank)[stat.ST_SIZE]
+
+ # If the user wants to use both a parallel and serial merge (for testing),
+ # rename the parallel output files so that the serial merge does not clobber them.
+ if args.merge == 'both' and args.rank == 0:
+ binfilemain = data_file_path(filemain)
+ idxfilemain = index_file_path(filemain)
+ os.rename(binfilemain, binfilemain + ".par")
+ os.rename(idxfilemain, idxfilemain + ".par")
+
+ # Total up number of bytes read across all ranks,
+ # and wait on all ranks before stopping the timer.
+ args.distctx.all_sum_(numbytes) + merge_end = time.time() + if args.rank == 0: + secs = merge_end - merge_start + byterate = numbytes[0] / secs if secs > 0.0 else 0.0 + msg("Parallel merge stats:") + msg(f" Scratch: {args.scratch}") + msg(f" Seconds to merge: {secs}") + msg(f" {int(numbytes)} bytes {format_byterate(byterate)}") + +def rank_files_merge_serial(args): + """Rank 0 merges data from all per-rank files into the final file.""" if args.rank == 0: - print("Merging rank files ...", flush=True) + msg("Merging rank files ...", flush=True) merge_start = time.time() numbytes = 0 @@ -497,25 +515,25 @@ def rank_files_merge(args): filebase = get_filename(args, key) output_bin_files[key] = data_file_path(filebase) output_idx_files[key] = index_file_path(filebase) + best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None builders[key] = make_builder(output_bin_files[key], impl=args.dataset_impl, - dtype=best_fitting_dtype(args.vocab_size)) + dtype=best_dtype) # merge all ranks into one file for rank in range(args.numranks): for key in args.columns: infile = get_filename(args, key, rank) - - print(f"Merging file {infile}", flush=True) builders[key].merge_file_(infile) # sum up the number of merged bytes binfile = data_file_path(infile) - filesize = os.stat(binfile)[stat.ST_SIZE] - numbytes += filesize + idxfile = index_file_path(infile) + numbytes += os.stat(binfile)[stat.ST_SIZE] + numbytes += os.stat(idxfile)[stat.ST_SIZE] # finalize the merged file - print("Finalizing merged file ...", flush=True) + msg("Finalizing merged file ...", flush=True) for key in args.columns: builders[key].finalize(output_idx_files[key]) del builders[key] # file closed in __del__ @@ -523,18 +541,31 @@ def rank_files_merge(args): merge_end = time.time() secs = merge_end - merge_start byterate = numbytes / secs if secs > 0.0 else 0.0 - print(f"Merged {args.numranks} files into {args.output_prefix}") - print("Merge stats:") - print(f" Seconds to merge: {secs}") - print(f" {numbytes} bytes {format_byterate(byterate)}") + msg(f"Merged {args.numranks} files into {args.output_prefix}") + msg("Serial merge stats:") + msg(f" Seconds to merge: {secs}") + msg(f" {numbytes} bytes {format_byterate(byterate)}") # hold everyone until rank 0 is done - barrier(args) + args.distctx.barrier() + +def rank_files_merge(args): + # use parallel merge if asked + if args.merge in ['parallel', 'both']: + rank_files_merge_parallel(args) + + # if using node-local storage, skip sequential merge + if args.scratch is not None: + return + + # can fall back to a serial merge + if args.merge in ['serial', 'both']: + rank_files_merge_serial(args) def rank_files_delete(args): # delete per-rank files if args.rank == 0: - print("Deleting rank files ...", flush=True) + msg("Deleting rank files ...", flush=True) for key in args.columns: filebase = get_filename(args, key, args.rank) @@ -548,24 +579,17 @@ def rank_files_delete(args): os.remove(idxfile) # hold everyone until all are done - barrier(args) + args.distctx.barrier() def main(): args = get_args() startup_start = time.time() - # connect processes and cache our rank and number of procs in args - init_distributed(args) - # load the dataset - dset, err = load_dset(args) - if dset is None: - if err is not None: - raise err - return + dset = load_dset(args) if args.rank == 0: print(dset) - print("Selecting features:", args.columns) + msg(f"Processing features: {args.columns}") # create sample index list, # optionally shuffle the list, @@ -579,26 +603,26 @@ def 
main():
 args.vocab_size = encoder.tokenizer.vocab_size

 # wait for all ranks before stopping timer
- barrier(args)
+ args.distctx.barrier()
 startup_end = time.time()
 if args.rank == 0:
- print("Seconds to startup:", startup_end - startup_start)
+ msg(f"Seconds to startup: {startup_end - startup_start}")

- # have each rank write its file, returns False if any rank had a problem
- success, err = rank_files_write(args, dset, idx, encoder)
- if not success:
+ # have each rank write its file,
+ # all ranks should raise an exception if any rank has a problem
+ try:
+ rank_files_write(args, dset, idx, encoder)
+ except Exception as e:
+ # If any process fails, we skip the merge since the resulting file would be invalid.
+ # We still delete files to clean up, since those might be invalid anyway.
 if args.rank == 0:
- # If any process fails, we skip the merge since the resulting file would be invalid.
- # We still delete files to clean up, since those might be invalid anyway.
- print(f"ERROR: At least one process failed to write its file, skipping merge and cleaning up", flush=True)
+ msgerr("At least one process failed to write its file, skipping merge and cleaning up", flush=True)

 # delete per-rank files, do this even on error
 rank_files_delete(args)

- # raise exception caught during write phase
- if err is not None:
- raise err
- return
+ # re-raise exception caught during write phase
+ raise e

 # all ranks were successful writing their file, merge them into one
 rank_files_merge(args)
@@ -606,5 +630,10 @@ def main():
 # delete per-rank files
 rank_files_delete(args)

+ end_time = time.time()
+ if args.rank == 0:
+ msg(f"Runtime: {end_time - startup_start} secs", flush=True)
+ msg("Done")
+
 if __name__ == '__main__':
 main()

From 1fea302d3acc9dd7c6b58738fd1fe867a4880c83 Mon Sep 17 00:00:00 2001
From: Adam Moody
Date: Thu, 26 Aug 2021 13:54:29 -0700
Subject: [PATCH 19/20] clean up merge

---
 tools/preprocess_data_dist.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py
index 62f872a26..1fae8968f 100644
--- a/tools/preprocess_data_dist.py
+++ b/tools/preprocess_data_dist.py
@@ -588,7 +588,12 @@ def main():
 startup_start = time.time()

 # load the dataset
- dset = load_dset(args)
+ if args.input.endswith(".jsonl"):
+ # assume file is JSONL format
+ dset = IndexedJSON(args.input, args.distctx)
+ else:
+ # otherwise load HuggingFace dataset
+ dset = load_dset(args)
 if args.rank == 0:
 print(dset)
 msg(f"Processing features: {args.columns}")

From d3603130235e6330d8c2aa7b6fb6b2e7ca00b57a Mon Sep 17 00:00:00 2001
From: Adam Moody
Date: Thu, 26 Aug 2021 13:56:14 -0700
Subject: [PATCH 20/20] clean up merge

---
 megatron/data/indexed_dataset.py | 35 ++++++--------------------
 tools/preprocess_data_dist.py | 2 +-
 2 files changed, 7 insertions(+), 30 deletions(-)

diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 99d23bf9a..025e9e333 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -459,21 +459,21 @@ def __init__(self, path, skip_warmup=False):
 offset = stream.tell()

 if not skip_warmup:
-# print_rank_0(" warming up index mmap file...")
+ print_rank_0(" warming up index mmap file...")
 _warmup_mmap_file(path)

 self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
 self._bin_buffer = memoryview(self._bin_buffer_mmap)
-# print_rank_0(" reading sizes...")
+ print_rank_0(" reading sizes...")
 self._sizes = np.frombuffer(
 self._bin_buffer,
 dtype=np.int32,
count=self._len, offset=offset) -# print_rank_0(" reading pointers...") + print_rank_0(" reading pointers...") self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) -# print_rank_0(" reading document index...") + print_rank_0(" reading document index...") self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, offset=offset + self._sizes.nbytes + self._pointers.nbytes) @@ -618,8 +618,8 @@ def merge_file_(self, another_file): index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype -# total_len = len(index.sizes)+len(self._sizes) -# print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") + total_len = len(index.sizes)+len(self._sizes) + print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") offset = len(self._sizes) self._sizes.extend(index.sizes) @@ -764,29 +764,6 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): if dtype_value is None: dtype_value = dtype_code if dtype_value != dtype_code: - dtype_valid = False - - # Check that we have consistent dtypes in all files from all ranks - dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) - - # Capture the last value in dim and data arrays before we delete any items. - # Note this may be zero on any rank that has no items, - # but zero is the correct value in that case. - # We use this last value to compute a shift value that will - # later be added to each element in our lists. - dim_shift = distctx.exscan(dim_offsets[-1]) - data_shift = distctx.exscan(data_offsets[-1]) - - # Drop the zero entry from the lists that start with - # a "0" value unless we're rank 0 - if distctx.rank != 0: - del data_offsets[0] - del dim_offsets[0] - del docs[0] - - # Compute total number of entires in data, size, dim, - # and docs lists across all ranks. Also compute the offset -======= dtype_rank_consistent = False # Check that we have consistent dtypes in all files from all ranks, diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 1fae8968f..afb349a6f 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -216,7 +216,7 @@ def get_args(): # TODO: perhaps more user friendly to disable scratch and print a warning? # check that serial merge is not attempted with scratch if args.scratch is not None and args.merge != 'parallel': - raise ValueError("The --scratch option is only valid with --merge=parallel") + raise ValueError("The --scratch option is only valid with --merge=parallel") return args
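Taken together, the offset arithmetic these patches settle on is small enough to exercise standalone. The sketch below copies the two helpers from megatron/data/indexed_dataset.py as they stand after this series; only the record sizes and element size in the demo are made-up values:

    import numpy as np

    def exscan_from_cumsum_(arr):
        # convert an inclusive scan (cumsum) to an exclusive scan, in place:
        # [10, 30, 35, 50] --> [0, 10, 30, 35]
        if arr.size > 1:
            arr[1:] = arr[:-1]
        if arr.size > 0:
            arr[0] = 0

    def get_pointers_with_total(sizes, elemsize, dtype):
        # scale sizes to bytes, then exclusive-scan to get byte offsets;
        # also return the total byte count (last element of the cumsum)
        pointers = np.array(sizes, dtype=dtype)
        pointers *= elemsize
        np.cumsum(pointers, axis=0, out=pointers)
        bytes_last = pointers[-1] if len(sizes) > 0 else 0
        exscan_from_cumsum_(pointers)
        return pointers, bytes_last

    # three records of 5, 10, and 2 tokens at 4 bytes per token
    pointers, total = get_pointers_with_total([5, 10, 2], elemsize=4, dtype=np.int64)
    print(pointers.tolist(), int(total))  # [0, 20, 60] 68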